Compare commits
703 Commits
v0.2.1
...
v0.2.5-hot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
130f15665c | ||
|
|
026e4376d4 | ||
|
|
cf571cf02b | ||
|
|
218671ef06 | ||
|
|
34de0bb9c5 | ||
|
|
8e6cf09056 | ||
|
|
5929072b76 | ||
|
|
aa69cd3a0c | ||
|
|
a726a81224 | ||
|
|
9aae6163f0 | ||
|
|
941527e7ee | ||
|
|
a3f05220d3 | ||
|
|
7446241735 | ||
|
|
6033d37537 | ||
|
|
1524d7b5ce | ||
|
|
e00341a4cc | ||
|
|
f5185d2e95 | ||
|
|
dc9003f9db | ||
|
|
07e0c70629 | ||
|
|
37f77e0990 | ||
|
|
aef1a57ea8 | ||
|
|
69af479224 | ||
|
|
f38223c97f | ||
|
|
1ac6702eb0 | ||
|
|
2510f60dce | ||
|
|
b9d7fb2598 | ||
|
|
a39ba564fa | ||
|
|
34310bfabe | ||
|
|
78fd189510 | ||
|
|
94836ed9af | ||
|
|
229eb5cc86 | ||
|
|
bbb2c6c903 | ||
|
|
5edf3f2b8a | ||
|
|
006c6cd159 | ||
|
|
9675982555 | ||
|
|
3ac8a9431b | ||
|
|
5c42a84c3e | ||
|
|
b9578bd08a | ||
|
|
035e56e42f | ||
|
|
5a90d4776d | ||
|
|
f81fdca62a | ||
|
|
729c283c63 | ||
|
|
c99f04314c | ||
|
|
dd9be2ed90 | ||
|
|
2327be7557 | ||
|
|
a7ffc19ba1 | ||
|
|
bbaa39c569 | ||
|
|
d1de0250e7 | ||
|
|
2d731c6412 | ||
|
|
6a6e64f487 | ||
|
|
b9201c918a | ||
|
|
7dedad898a | ||
|
|
d497189352 | ||
|
|
fa4da8f467 | ||
|
|
e9ff742162 | ||
|
|
3849cfb835 | ||
|
|
c453af23c6 | ||
|
|
bcf2376f5a | ||
|
|
be2f56ae6a | ||
|
|
cbc9602495 | ||
|
|
c72ce381c0 | ||
|
|
2ef54168fc | ||
|
|
b33ccf00f9 | ||
|
|
829eb4b3be | ||
|
|
6c49456c13 | ||
|
|
fc8f06ee14 | ||
|
|
120a524b7e | ||
|
|
bd037ac3a3 | ||
|
|
b8ea427029 | ||
|
|
275be47224 | ||
|
|
4ea9c7e660 | ||
|
|
92d78d9a52 | ||
|
|
a820001eea | ||
|
|
8273f6d217 | ||
|
|
bd63e0fce8 | ||
|
|
12ba3d473e | ||
|
|
0b9cc0f068 | ||
|
|
5ca397befa | ||
|
|
da735fe776 | ||
|
|
b4f69f2cff | ||
|
|
1885c00cbc | ||
|
|
1e4fdeb1a6 | ||
|
|
cb7dbb0ed4 | ||
|
|
44083aec79 | ||
|
|
4a9b743153 | ||
|
|
b462e17a5b | ||
|
|
b272a52b57 | ||
|
|
3f87c64e83 | ||
|
|
1795364f5f | ||
|
|
e69fbb2f97 | ||
|
|
32b40fc6bf | ||
|
|
f039ea7f56 | ||
|
|
41334f5f1e | ||
|
|
79b19b744e | ||
|
|
2103410694 | ||
|
|
2143d94e83 | ||
|
|
9ae2612945 | ||
|
|
3a09b26b6d | ||
|
|
e381449aec | ||
|
|
bacffc94d9 | ||
|
|
7044f705e7 | ||
|
|
6db4fe28a7 | ||
|
|
f966176694 | ||
|
|
bd24de4577 | ||
|
|
dc2ea5c007 | ||
|
|
4fb673077a | ||
|
|
b3a136ac03 | ||
|
|
22f1bfa3fa | ||
|
|
f6ad0aab94 | ||
|
|
371fdeb948 | ||
|
|
f7a0af75c4 | ||
|
|
b31e526e4d | ||
|
|
26abf7b586 | ||
|
|
d477e24e34 | ||
|
|
3ca3e8e023 | ||
|
|
3bd374495b | ||
|
|
b26f60ee8d | ||
|
|
df681eaf22 | ||
|
|
01458ac111 | ||
|
|
6c7a68802b | ||
|
|
e3074b833f | ||
|
|
1097d699f8 | ||
|
|
55b4e0ebd3 | ||
|
|
0011a8ce9f | ||
|
|
100bf4fa49 | ||
|
|
6da5b81311 | ||
|
|
787adf5423 | ||
|
|
01b500e7d1 | ||
|
|
e64603ea27 | ||
|
|
4219e12cc0 | ||
|
|
c86ccf0931 | ||
|
|
d4571fb75b | ||
|
|
ec2369c397 | ||
|
|
6ebd48408b | ||
|
|
7e7b54593c | ||
|
|
f93c9f5cd2 | ||
|
|
a810fbe008 | ||
|
|
600a914bd9 | ||
|
|
b1688950c4 | ||
|
|
d8e3f9b7b8 | ||
|
|
08d55e4463 | ||
|
|
55e2baa865 | ||
|
|
55174dc707 | ||
|
|
d57e3b3f64 | ||
|
|
aa42cd0aec | ||
|
|
ac6d9a39ec | ||
|
|
9b07775395 | ||
|
|
936fb8b8a1 | ||
|
|
6c8318b696 | ||
|
|
d554079e2b | ||
|
|
37464a101e | ||
|
|
c5674246b0 | ||
|
|
f076199e3f | ||
|
|
8326db1143 | ||
|
|
992e41e0a0 | ||
|
|
076e95d5c2 | ||
|
|
dfd79e5972 | ||
|
|
b16c9d53ef | ||
|
|
5fe85fb457 | ||
|
|
b45f470310 | ||
|
|
0ecda33ab8 | ||
|
|
7fcfca455a | ||
|
|
6a32154b8f | ||
|
|
132206677f | ||
|
|
30a8775548 | ||
|
|
045bc9aefc | ||
|
|
d5c46574cc | ||
|
|
37fea09403 | ||
|
|
063e8fae43 | ||
|
|
184c4fbf7f | ||
|
|
e19d27f640 | ||
|
|
ea96830758 | ||
|
|
d2edbc738d | ||
|
|
03bc8c8280 | ||
|
|
68908213da | ||
|
|
b3d5add89a | ||
|
|
7fe2d8fbe1 | ||
|
|
de545a69ca | ||
|
|
dc48ba540d | ||
|
|
81e92b4fa6 | ||
|
|
ebad5e00a3 | ||
|
|
bca03f1365 | ||
|
|
c89f55f0bd | ||
|
|
4d98bace87 | ||
|
|
dcdc899528 | ||
|
|
b57aa55001 | ||
|
|
d0c0168c20 | ||
|
|
af596a09cf | ||
|
|
6849c620b8 | ||
|
|
12598f0dca | ||
|
|
3f4ce4f16f | ||
|
|
4aaf0d8d5c | ||
|
|
65db056e09 | ||
|
|
232cef7cb9 | ||
|
|
73a432879a | ||
|
|
09afec17f9 | ||
|
|
ac47ab3deb | ||
|
|
8b3d7c168a | ||
|
|
60e8eb63ac | ||
|
|
4f29cd24b8 | ||
|
|
ba73ade2a0 | ||
|
|
7559305fc9 | ||
|
|
6985f553f9 | ||
|
|
8fc15df6d0 | ||
|
|
eb8160a5af | ||
|
|
16cf6eee9b | ||
|
|
320f684354 | ||
|
|
12062a5440 | ||
|
|
4423a9d979 | ||
|
|
1eb44defb6 | ||
|
|
e253fba2e9 | ||
|
|
c05d95924f | ||
|
|
2db583d62d | ||
|
|
59d8e1bf9f | ||
|
|
1001344c27 | ||
|
|
8a0e2da03f | ||
|
|
f58886be6f | ||
|
|
3c1d3b4d6a | ||
|
|
bbba995ff7 | ||
|
|
0033b5be80 | ||
|
|
87d53fb9b7 | ||
|
|
157031f23e | ||
|
|
8a37869489 | ||
|
|
5c10f11681 | ||
|
|
7b72bf0cd0 | ||
|
|
be29666916 | ||
|
|
8d4c5b5b33 | ||
|
|
52260f469a | ||
|
|
c566d22836 | ||
|
|
75f59a86c8 | ||
|
|
1eaf12446f | ||
|
|
efdd42426e | ||
|
|
62c557deae | ||
|
|
db1da4a61a | ||
|
|
db46c186aa | ||
|
|
677a603835 | ||
|
|
447d8790ad | ||
|
|
7a78f15a90 | ||
|
|
c1941809e9 | ||
|
|
623aaf8a0e | ||
|
|
7b3bf41120 | ||
|
|
0c3960eb0b | ||
|
|
fe3c31c08c | ||
|
|
94600cdbfc | ||
|
|
4e7ab3d7e3 | ||
|
|
47b25d7a26 | ||
|
|
0249666fa4 | ||
|
|
2e8504ce2f | ||
|
|
aca7d25001 | ||
|
|
2444309bc2 | ||
|
|
97c5a78d48 | ||
|
|
effdb88455 | ||
|
|
2f0ce3852e | ||
|
|
5475496399 | ||
|
|
b569d77a23 | ||
|
|
dfa7a2d4cf | ||
|
|
169e01276d | ||
|
|
07e698265e | ||
|
|
0632d7611f | ||
|
|
b3f39eedac | ||
|
|
46ed7e38bf | ||
|
|
8c5199d32d | ||
|
|
36ed833d64 | ||
|
|
47969ce61e | ||
|
|
06731e2026 | ||
|
|
123347169d | ||
|
|
f9101a744c | ||
|
|
97eb33000f | ||
|
|
60231ec88d | ||
|
|
3364374dc6 | ||
|
|
a3cf773e75 | ||
|
|
4092d5fbaf | ||
|
|
07e9fde9e8 | ||
|
|
9b4613630b | ||
|
|
f125d11b6d | ||
|
|
657d48a5f9 | ||
|
|
3735bdde19 | ||
|
|
3f906d81cb | ||
|
|
7c1f622797 | ||
|
|
cfe696ae8d | ||
|
|
021c50a8f2 | ||
|
|
95745ba869 | ||
|
|
adfae54816 | ||
|
|
10ed093eb8 | ||
|
|
c96df6bfa5 | ||
|
|
0126d18525 | ||
|
|
9e6e8f50f8 | ||
|
|
7e0b31626f | ||
|
|
1d9e249a77 | ||
|
|
88b89ef315 | ||
|
|
62b7925cb0 | ||
|
|
cc1528f550 | ||
|
|
1c8a83140b | ||
|
|
34276e2066 | ||
|
|
71abd16ae7 | ||
|
|
918e7285c4 | ||
|
|
056d422c71 | ||
|
|
5ee54f4e0e | ||
|
|
260c75e70c | ||
|
|
2d7401922f | ||
|
|
8c7a1348cf | ||
|
|
24fbdbd716 | ||
|
|
aad8f0e36b | ||
|
|
15cad44f08 | ||
|
|
0271454671 | ||
|
|
d0ddf288ca | ||
|
|
bc250ac377 | ||
|
|
7922fc3b0e | ||
|
|
161da723b9 | ||
|
|
514c19a247 | ||
|
|
41550d4a41 | ||
|
|
33cc3c1c3f | ||
|
|
7d15182202 | ||
|
|
8f0a1d9c6e | ||
|
|
72b5e5cf8e | ||
|
|
62aba2dd38 | ||
|
|
cdd6b80089 | ||
|
|
333836f5e7 | ||
|
|
a2dfda3471 | ||
|
|
2d28b4b05c | ||
|
|
87f9bcc6a3 | ||
|
|
48aca996ff | ||
|
|
c8c7e9b304 | ||
|
|
97ff023995 | ||
|
|
e273a336f8 | ||
|
|
34f0c3b90c | ||
|
|
7c2902d2b8 | ||
|
|
8e41afdffc | ||
|
|
7268886294 | ||
|
|
cbae900866 | ||
|
|
ffff138a6f | ||
|
|
88c95db8d0 | ||
|
|
56e657a0bb | ||
|
|
bc36b79105 | ||
|
|
5694bc0230 | ||
|
|
36130031f9 | ||
|
|
b8f1095f53 | ||
|
|
442fa09533 | ||
|
|
42ef2efbc8 | ||
|
|
ead3080b2b | ||
|
|
c6ea31c296 | ||
|
|
21eae29bb7 | ||
|
|
406740b524 | ||
|
|
9d30bc4062 | ||
|
|
fad91b64ab | ||
|
|
2132e71a81 | ||
|
|
bd8a451879 | ||
|
|
24dafa7359 | ||
|
|
3b5df793fb | ||
|
|
da835b6138 | ||
|
|
7e650d86a5 | ||
|
|
308e28cecc | ||
|
|
9a3c74fb64 | ||
|
|
f571f0688a | ||
|
|
1e9c32a102 | ||
|
|
8c69199689 | ||
|
|
3efb3e8a35 | ||
|
|
cfcb278406 | ||
|
|
9e195ea63b | ||
|
|
dc0d34c281 | ||
|
|
72076c218f | ||
|
|
151fd3b950 | ||
|
|
2d484fcb30 | ||
|
|
6e0407f404 | ||
|
|
8670aaba1e | ||
|
|
f27de7df35 | ||
|
|
63fa4dc8ec | ||
|
|
a191e32f71 | ||
|
|
9a38e8a4a0 | ||
|
|
6194222289 | ||
|
|
0d077eaeb7 | ||
|
|
b2c7a9a005 | ||
|
|
be01f1869e | ||
|
|
9f2b6390b0 | ||
|
|
e196f86e30 | ||
|
|
ec41d45234 | ||
|
|
567d1ba18b | ||
|
|
df8706983b | ||
|
|
8697498b32 | ||
|
|
af917c538a | ||
|
|
034e97dfa6 | ||
|
|
5e1e5f68e1 | ||
|
|
fb76f765cc | ||
|
|
7a3f57261d | ||
|
|
a1a460625d | ||
|
|
3f42ea2c61 | ||
|
|
940c594066 | ||
|
|
5e47fc45ab | ||
|
|
b471d56a86 | ||
|
|
61f8029205 | ||
|
|
e2f047d035 | ||
|
|
1aff4eda67 | ||
|
|
a6c5c44ed8 | ||
|
|
3f389d685a | ||
|
|
5d5351f0bc | ||
|
|
1224802ac6 | ||
|
|
e919f89caf | ||
|
|
bb8e7a68ea | ||
|
|
48f95e0ea4 | ||
|
|
931e9bcf0d | ||
|
|
67a3351c4c | ||
|
|
dfe5eeed7b | ||
|
|
3464573f17 | ||
|
|
9cf49c9c75 | ||
|
|
4e837cb90c | ||
|
|
e4fb58496b | ||
|
|
15a254c0cd | ||
|
|
d62746fc8c | ||
|
|
4b8b6fe407 | ||
|
|
6754834eb3 | ||
|
|
be98db561d | ||
|
|
574d0afc72 | ||
|
|
31c8ad611c | ||
|
|
b23730388d | ||
|
|
1b853aa893 | ||
|
|
36cb0a12ad | ||
|
|
5439eacf2d | ||
|
|
2687c3b80e | ||
|
|
fa009327ad | ||
|
|
838bd46e83 | ||
|
|
ccc2009aa8 | ||
|
|
d9aba92314 | ||
|
|
696b0475a8 | ||
|
|
e7370489e8 | ||
|
|
f1503b2238 | ||
|
|
cd4661e878 | ||
|
|
364e01ec7a | ||
|
|
ffb7b0ba38 | ||
|
|
22151eb49b | ||
|
|
d0354345f6 | ||
|
|
b1e61eb1e4 | ||
|
|
36e0ed15b6 | ||
|
|
095dfc2879 | ||
|
|
17dea9433e | ||
|
|
c285444e2f | ||
|
|
8ba402d080 | ||
|
|
88ab86734d | ||
|
|
504d87b0b0 | ||
|
|
b0d5818351 | ||
|
|
8826a01d32 | ||
|
|
cfb7a40841 | ||
|
|
8267761890 | ||
|
|
a651ae6ed4 | ||
|
|
a01911ba5f | ||
|
|
ee50b25d06 | ||
|
|
a67be85858 | ||
|
|
59c5a3973a | ||
|
|
d76d7343ff | ||
|
|
2b9638e7d3 | ||
|
|
3459a73705 | ||
|
|
bd480a466b | ||
|
|
4c34cb55b6 | ||
|
|
7347f9104c | ||
|
|
e137e4a38a | ||
|
|
b5989bbc25 | ||
|
|
c31ff7ceef | ||
|
|
9206c7642a | ||
|
|
d1b4f2b6c2 | ||
|
|
75066f2827 | ||
|
|
303f3aefef | ||
|
|
44fb5e0fd5 | ||
|
|
17a695120a | ||
|
|
6dc716eaf8 | ||
|
|
194be086d4 | ||
|
|
cca3900678 | ||
|
|
4fe32b7dbc | ||
|
|
c49603c25b | ||
|
|
8de85a4041 | ||
|
|
58a2135fa4 | ||
|
|
ab9a97db22 | ||
|
|
d291c241d5 | ||
|
|
24d4cb9b94 | ||
|
|
5b9adb799f | ||
|
|
38b41df36b | ||
|
|
34a9befe5c | ||
|
|
67fd579074 | ||
|
|
e2714b942d | ||
|
|
6b2556f870 | ||
|
|
43e6e9d201 | ||
|
|
131e0cc4c7 | ||
|
|
537be81b8f | ||
|
|
765168db7f | ||
|
|
1e16b06a24 | ||
|
|
42b59a644d | ||
|
|
d9fa9039bb | ||
|
|
cd4c93a5cb | ||
|
|
808961243d | ||
|
|
4d80e119f7 | ||
|
|
10c87edae1 | ||
|
|
0eb335d112 | ||
|
|
b8b26ccfe5 | ||
|
|
e89c23da4d | ||
|
|
f3da8956d9 | ||
|
|
b1147d77af | ||
|
|
66bc2fb41f | ||
|
|
4e538a6df8 | ||
|
|
ced087f8ae | ||
|
|
0f1eed0b1e | ||
|
|
95f15b77a3 | ||
|
|
f9ccfd5ca0 | ||
|
|
7207d7c847 | ||
|
|
00c4a524b7 | ||
|
|
9c3e0b5541 | ||
|
|
33bfe33eb3 | ||
|
|
3127c382a4 | ||
|
|
1748a390ec | ||
|
|
a7c0837049 | ||
|
|
44bf1eeae2 | ||
|
|
762b7a8ef1 | ||
|
|
102712a16e | ||
|
|
40810c59d7 | ||
|
|
35a10e86b5 | ||
|
|
c0c985494d | ||
|
|
8984ba7aef | ||
|
|
179869d481 | ||
|
|
5f29956f2b | ||
|
|
7e56c09620 | ||
|
|
dbc4ba84c2 | ||
|
|
9e4a527675 | ||
|
|
2e7f6afe3f | ||
|
|
45833542a7 | ||
|
|
1be6de30d7 | ||
|
|
981d78c8ba | ||
|
|
fbc7bedb6c | ||
|
|
9a4b1f0937 | ||
|
|
4786b0c5d4 | ||
|
|
17bed26096 | ||
|
|
511e16f1d3 | ||
|
|
18204bc1f7 | ||
|
|
e5e914903c | ||
|
|
7ba443afa5 | ||
|
|
b58d97fad3 | ||
|
|
d2a67a53b5 | ||
|
|
c0b556000c | ||
|
|
462c3b0696 | ||
|
|
d34ad73439 | ||
|
|
2c21712d58 | ||
|
|
2862db3534 | ||
|
|
bf3e30dac0 | ||
|
|
ce01e588c9 | ||
|
|
2a23082203 | ||
|
|
d373f924f6 | ||
|
|
eaf46ee006 | ||
|
|
d51355a0ad | ||
|
|
1e481a311a | ||
|
|
375660f232 | ||
|
|
46abb23ee8 | ||
|
|
8555bb697c | ||
|
|
f821893653 | ||
|
|
f6031baee4 | ||
|
|
75b3ea1f05 | ||
|
|
c818ba7bc7 | ||
|
|
74f0018962 | ||
|
|
3a0f07d36f | ||
|
|
8fb9e779a6 | ||
|
|
c5a794f1b5 | ||
|
|
3aa2cdd754 | ||
|
|
d93d52cf10 | ||
|
|
2abbd5a7fb | ||
|
|
2a10e9f7ee | ||
|
|
166d05afe9 | ||
|
|
2eff8d1962 | ||
|
|
93c9e76c4b | ||
|
|
021cb09b82 | ||
|
|
28e6939884 | ||
|
|
8847039d76 | ||
|
|
a047cf2e91 | ||
|
|
a8ae16e321 | ||
|
|
2694576a32 | ||
|
|
e4f10670f6 | ||
|
|
1324ba3a49 | ||
|
|
73c7810310 | ||
|
|
d160076267 | ||
|
|
a53be31765 | ||
|
|
ed8c1c7c19 | ||
|
|
159c8d1ff9 | ||
|
|
8932d455d8 | ||
|
|
3af183f6c3 | ||
|
|
4475be51cc | ||
|
|
c3ea3b751b | ||
|
|
e2c67d0c5b | ||
|
|
87731090ca | ||
|
|
80ca247435 | ||
|
|
a5b8d3afa5 | ||
|
|
1f615a06ad | ||
|
|
4123560a98 | ||
|
|
5267bd60a5 | ||
|
|
f76bffb482 | ||
|
|
51185c83c9 | ||
|
|
f1f887faae | ||
|
|
d53cbe7868 | ||
|
|
722746c78b | ||
|
|
46f0f3cee9 | ||
|
|
e1f5607836 | ||
|
|
ebc41b2eec | ||
|
|
7cd0d78424 | ||
|
|
d740559749 | ||
|
|
399357f752 | ||
|
|
3b4b474ce8 | ||
|
|
4534e46811 | ||
|
|
7bfa7b3f02 | ||
|
|
1cc34d8e62 | ||
|
|
2eff6b2e9d | ||
|
|
b046411302 | ||
|
|
6ab65b3626 | ||
|
|
cf321f9b09 | ||
|
|
8228d38859 | ||
|
|
c2e3110fa2 | ||
|
|
85681db7b7 | ||
|
|
1fc04c37d3 | ||
|
|
0fd8a122fb | ||
|
|
e3b6ede992 | ||
|
|
3601737869 | ||
|
|
9de6b4f151 | ||
|
|
4f4f55d67f | ||
|
|
714c624dc6 | ||
|
|
94cced8323 | ||
|
|
9b8ed16e37 | ||
|
|
a5e44cd229 | ||
|
|
eccc208229 | ||
|
|
79cfabb45d | ||
|
|
af6e1e2b99 | ||
|
|
4ad51c1b24 | ||
|
|
1919580759 | ||
|
|
b27ffe57e6 | ||
|
|
c115bcde54 | ||
|
|
c44712167f | ||
|
|
1aabaff1f2 | ||
|
|
21c0383efb | ||
|
|
313f19eba4 | ||
|
|
c6bcf53fea | ||
|
|
86812b34d1 | ||
|
|
15f9c49418 | ||
|
|
6e18c92a13 | ||
|
|
7870c6c33f | ||
|
|
ebe018347b | ||
|
|
86fe6fe5ab | ||
|
|
9e828b1750 | ||
|
|
45adb9627a | ||
|
|
940d3d4567 | ||
|
|
6bd7b2b8bb | ||
|
|
f2d6fd7b08 | ||
|
|
7219274d94 | ||
|
|
b84c82880c | ||
|
|
fcc418b4a0 | ||
|
|
15c0bb4c9e | ||
|
|
8db4f914d8 | ||
|
|
f3f9211c9c | ||
|
|
51680b7077 | ||
|
|
a2a69840f7 | ||
|
|
3a4a7590c2 | ||
|
|
bcc8b7ce3c | ||
|
|
1c7fe6d134 | ||
|
|
c4039f52bd | ||
|
|
bd851d5e86 | ||
|
|
00e448c5d6 | ||
|
|
4aeec8afbf | ||
|
|
f10432bf3f | ||
|
|
f0efed8aa1 | ||
|
|
4a4931bee2 | ||
|
|
afcf12ebc9 | ||
|
|
8f86d3417d | ||
|
|
92dfc54c4c | ||
|
|
c93bcb8678 | ||
|
|
98b2da9123 | ||
|
|
cd5f1a1b28 | ||
|
|
0e2e495d09 | ||
|
|
84c6c7e2a6 | ||
|
|
c8ebf9c75a | ||
|
|
29852ff0a5 | ||
|
|
f06ca62589 | ||
|
|
3f39a2be12 | ||
|
|
575190a96d | ||
|
|
78559d98eb | ||
|
|
398964c747 | ||
|
|
a634565296 | ||
|
|
a5ecbec9a6 | ||
|
|
fe79978f88 | ||
|
|
978ec8bc75 | ||
|
|
6e77f5b068 | ||
|
|
c9dbb64269 | ||
|
|
546d32e3eb | ||
|
|
616f6401b4 | ||
|
|
d047190453 | ||
|
|
17504b1b9c | ||
|
|
5a0d3df689 | ||
|
|
871304c89b | ||
|
|
8155150e45 | ||
|
|
d9fb8edaa9 | ||
|
|
dda61679bd | ||
|
|
6ac10a8297 | ||
|
|
0695c11739 | ||
|
|
7a4297c4f1 | ||
|
|
2c9e5df27d | ||
|
|
6db37d35ed | ||
|
|
ceee4fe5cf | ||
|
|
130b4a57de | ||
|
|
1cee27e830 | ||
|
|
ba2ff053f9 | ||
|
|
227665439f | ||
|
|
1a2e043ec2 | ||
|
|
89500df0ac | ||
|
|
cb4e80f1bc |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -21,6 +21,7 @@ examples/
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
.hypothesis/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
@@ -35,3 +36,5 @@ nltk_data/
|
||||
tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_redbear/target
|
||||
|
||||
28618
api/General_purpose_entity.ttl
Normal file
28618
api/General_purpose_entity.ttl
Normal file
File diff suppressed because it is too large
Load Diff
7
api/app/cache/__init__.py
vendored
7
api/app/cache/__init__.py
vendored
@@ -2,10 +2,7 @@
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
__all__ = []
|
||||
|
||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -2,11 +2,7 @@
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
__all__ = []
|
||||
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
134
api/app/cache/memory/emotion_memory.py
vendored
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
136
api/app/cache/memory/implicit_memory.py
vendored
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
@@ -3,8 +3,14 @@ import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
@@ -63,15 +69,22 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -81,10 +94,14 @@ celery_app.autodiscover_tasks(['app'])
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
# 构建定时任务配置
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
@@ -103,16 +120,16 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -19,14 +19,18 @@ from . import (
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
mcp_market_controller,
|
||||
mcp_market_config_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
memory_working_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
prompt_optimizer_controller,
|
||||
@@ -39,12 +43,9 @@ from . import (
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -61,6 +62,8 @@ manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(mcp_market_controller.router)
|
||||
manager_router.include_router(mcp_market_config_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
@@ -77,7 +80,6 @@ manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
@@ -90,5 +92,7 @@ manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -454,7 +455,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,7 +477,8 @@ async def draft_run(
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
"has_variables": bool(payload.variables),
|
||||
"has_files": bool(payload.files)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -490,7 +493,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -798,7 +802,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -872,3 +877,75 @@ async def update_workflow_config(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
app_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
- daily_new_users: 每日新增用户数
|
||||
- total_new_users: 总新增用户数
|
||||
- daily_api_calls: 每日API调用次数
|
||||
- total_api_calls: 总API调用次数
|
||||
- daily_tokens: 每日token消耗
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_api_statistics(
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
- date: 日期
|
||||
- total_calls: 当日总调用次数
|
||||
- app_calls: 当日应用调用次数
|
||||
- service_calls: 当日服务调用次数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -61,6 +61,7 @@ async def login_for_access_token(
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
|
||||
@@ -7,11 +7,13 @@ Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -20,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -32,11 +35,11 @@ router = APIRouter(
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -45,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -78,7 +81,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
@@ -157,6 +160,7 @@ def update_emotion_config(
|
||||
}
|
||||
}
|
||||
"""
|
||||
config.config_id=resolve_config_id(config.config_id, db)
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
|
||||
@@ -11,6 +11,7 @@ Routes:
|
||||
"""
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
@@ -45,35 +46,40 @@ emotion_service = EmotionAnalyticsService()
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
"limit": request.limit
|
||||
"limit": request.limit,
|
||||
"language_type": language
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
limit=request.limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
@@ -84,7 +90,7 @@ async def get_emotion_tags(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -97,15 +103,18 @@ async def get_emotion_tags(
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
@@ -113,7 +122,7 @@ async def get_emotion_wordcloud(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
@@ -121,7 +130,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
@@ -131,7 +140,7 @@ async def get_emotion_wordcloud(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -144,11 +153,14 @@ async def get_emotion_wordcloud(
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 验证时间范围参数
|
||||
if request.time_range not in ["7d", "30d", "90d"]:
|
||||
raise HTTPException(
|
||||
@@ -159,22 +171,22 @@ async def get_emotion_health(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score") or 0,
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
@@ -186,7 +198,7 @@ async def get_emotion_health(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -196,64 +208,112 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
||||
extra={"group_id": request.group_id}
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请右上角刷新生成新建议",
|
||||
""
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -265,11 +325,11 @@ async def get_emotion_suggestions(
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
@@ -280,6 +340,9 @@ async def generate_emotion_suggestions(
|
||||
新生成的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
extra={
|
||||
@@ -290,15 +353,15 @@ async def generate_emotion_suggestions(
|
||||
# 调用服务层生成建议
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -320,4 +383,4 @@ async def generate_emotion_suggestions(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
@@ -29,7 +29,7 @@ from app.core.storage_exceptions import (
|
||||
StorageUploadError,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
@@ -143,6 +143,141 @@ async def upload_file(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/share/files", response_model=ApiResponse)
|
||||
async def upload_file_with_share_token(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend using share_token authentication.
|
||||
"""
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Shared app not found"
|
||||
)
|
||||
|
||||
# Get app to access workspace_id
|
||||
app = db.query(App).filter(
|
||||
App.id == share.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="App not found"
|
||||
)
|
||||
|
||||
# Get workspace to access tenant_id
|
||||
workspace = db.query(Workspace).filter(
|
||||
Workspace.id == app.workspace_id
|
||||
).first()
|
||||
|
||||
if not workspace:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workspace not found"
|
||||
)
|
||||
|
||||
tenant_id = workspace.tenant_id
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, share_token={share_data.share_token}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed (share): {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
file_id: uuid.UUID,
|
||||
@@ -310,7 +445,7 @@ async def get_file_url(
|
||||
try:
|
||||
if permanent:
|
||||
# Generate permanent URL (no expiration check)
|
||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
url = f"{server_url}/storage/permanent/{file_id}"
|
||||
return success(
|
||||
data={
|
||||
|
||||
@@ -122,10 +122,52 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
@@ -137,7 +179,7 @@ async def get_preference_tags(
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
@@ -146,25 +188,21 @@ async def get_preference_tags(
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
@@ -192,17 +230,17 @@ async def get_preference_tags(
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -211,46 +249,42 @@ async def get_dimension_portrait(
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -259,46 +293,42 @@ async def get_interest_area_distribution(
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
@@ -309,7 +339,7 @@ async def get_behavior_habits(
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
@@ -317,25 +347,21 @@ async def get_behavior_habits(
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
@@ -368,11 +394,11 @@ async def get_behavior_habits(
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -9,13 +9,16 @@ from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common import settings
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model
|
||||
@@ -484,3 +487,99 @@ async def rebuild_knowledge_graph(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = YuqueAPIClient(
|
||||
user_id=yuque_user_id,
|
||||
token=yuque_token
|
||||
)
|
||||
async with api_client as client:
|
||||
repos = await client.get_user_repos()
|
||||
if repos:
|
||||
return success(msg="Successfully auth yuque info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth yuque info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = FeishuAPIClient(
|
||||
app_id=feishu_app_id,
|
||||
app_secret=feishu_app_secret
|
||||
)
|
||||
async with api_client as client:
|
||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||
if files:
|
||||
return success(msg="Successfully auth feishu info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth feishu info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base information from the database
|
||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. sync knowledge
|
||||
# from app.tasks import sync_knowledge_for_kb
|
||||
# sync_knowledge_for_kb(kb_id)
|
||||
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
336
api/app/controllers/mcp_market_config_controller.py
Normal file
336
api/app/controllers/mcp_market_config_controller.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import requests
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
from modelscope.hub.errors import raise_for_http_status
|
||||
from modelscope.hub.mcp_api import MCPApi
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_market_configs",
|
||||
tags=["mcp_market_configs"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
||||
async def get_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': page,
|
||||
'page_size': pagesize,
|
||||
'search': keywords
|
||||
}
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": mcp_server_list,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
server_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information for a specific MCP Server
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
||||
async def create_mcp_market_config(
|
||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market config
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_config_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config_by_mcp_market_id(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market config exists
|
||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market_config, field):
|
||||
old_value = getattr(db_mcp_market_config, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market_config, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market config
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market config exists
|
||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
||||
return success(msg="The mcp market config has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
262
api/app/controllers/mcp_market_controller.py
Normal file
262
api/app/controllers/mcp_market_controller.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_markets",
|
||||
tags=["mcp_markets"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
||||
async def get_mcp_markets(
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp markets list in pages
|
||||
- Support keyword search for name,description
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + mcp_market list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = []
|
||||
|
||||
# Keyword search (fuzzy matching of mcp market name,description)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug("Start executing mcp market paging query")
|
||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market", response_model=ApiResponse)
|
||||
async def create_mcp_market(
|
||||
create_data: mcp_market_schema.McpMarketCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {create_data.name}"
|
||||
)
|
||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market information from the database
|
||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="Successfully obtained mcp market information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
update_data: mcp_market_schema.McpMarketUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market exists
|
||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. not updating the name (name already exists)
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_mcp_market.name:
|
||||
# Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {name}"
|
||||
)
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market, field):
|
||||
old_value = getattr(db_mcp_market, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market exists
|
||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market
|
||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
||||
return success(msg="The mcp market has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
@@ -118,6 +119,7 @@ async def download_log(
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -125,14 +127,18 @@ async def write_server(
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -160,19 +166,19 @@ async def write_server(
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
messages_list, # 传递结构化消息列表
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -189,6 +195,7 @@ async def write_server(
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -196,15 +203,19 @@ async def write_server_async(
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -226,10 +237,10 @@ async def write_server_async(
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -255,16 +266,14 @@ async def read_server(
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
user_input: Read request with message, history, search_switch, and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -279,12 +288,13 @@ async def read_server(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
@@ -295,17 +305,20 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -403,7 +416,7 @@ async def read_server_async(
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
@@ -447,7 +460,7 @@ async def get_read_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -524,7 +537,7 @@ async def get_write_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -578,16 +591,16 @@ async def status_type(
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
user_input: Request containing user message and end_user_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
@@ -595,11 +608,11 @@ async def status_type(
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
@@ -620,12 +633,11 @@ async def get_knowledge_type_stats_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
@@ -652,7 +664,6 @@ async def get_knowledge_type_stats_api(
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
@@ -660,28 +671,18 @@ async def get_hot_memory_tags_by_user_api(
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
model_id=model_id,
|
||||
limit=limit
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
@@ -697,7 +698,7 @@ async def get_user_profile_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取工作空间下Popular Memory Tags,包含:
|
||||
获取用户详情,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -49,63 +50,134 @@ async def get_workspace_end_users(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
||||
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
memory_configs_map = {}
|
||||
if end_user_ids:
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
get_end_users_connected_configs_batch,
|
||||
end_user_ids, db
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
# 失败时使用空字典,不影响其他数据返回
|
||||
return {}
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
# RAG 模式:批量查询
|
||||
try:
|
||||
chunk_map = await asyncio.to_thread(
|
||||
memory_dashboard_service.get_users_total_chunk_batch,
|
||||
end_user_ids, db, current_user
|
||||
)
|
||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
|
||||
# 从批量查询结果中获取配置信息
|
||||
user_id = str(end_user.id)
|
||||
memory_config_info = memory_configs_map.get(user_id, {
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None
|
||||
})
|
||||
|
||||
# 只保留需要的字段,移除 error 字段(如果有)
|
||||
memory_config = {
|
||||
"memory_config_id": memory_config_info.get("memory_config_id"),
|
||||
"memory_config_name": memory_config_info.get("memory_config_name")
|
||||
}
|
||||
|
||||
result.append(
|
||||
{
|
||||
'end_user': end_user,
|
||||
'memory_num': memory_num,
|
||||
'memory_config': memory_config
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
},
|
||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||
'memory_config': {
|
||||
"memory_config_id": config_info.get("memory_config_id"),
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -398,6 +470,8 @@ async def get_chunk_insight(
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -432,6 +506,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -492,17 +575,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -531,10 +619,23 @@ async def dashboard_data(
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
包含情景记忆总览和详情查询接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -14,6 +15,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_episodic_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
translate_episodic_type,
|
||||
)
|
||||
from app.services.memory_episodic_service import memory_episodic_service
|
||||
|
||||
@@ -84,6 +86,7 @@ async def get_episodic_memory_overview_api(
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_episodic_memory_details_api(
|
||||
request: EpisodicMemoryDetailsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -111,6 +114,11 @@ async def get_episodic_memory_details_api(
|
||||
summary_id=request.summary_id
|
||||
)
|
||||
|
||||
# 根据语言参数翻译 episodic_type
|
||||
language = get_language_from_header(language_type)
|
||||
if "episodic_type" in result:
|
||||
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -33,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -83,7 +84,8 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
config_id = resolve_config_id((config_id), db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
@@ -106,7 +108,7 @@ async def trigger_forgetting_cycle(
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
@@ -128,7 +130,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -157,6 +159,7 @@ async def read_forgetting_config(
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
@@ -194,6 +197,8 @@ async def update_forgetting_config(
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -236,7 +241,7 @@ async def update_forgetting_config(
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -246,7 +251,7 @@ async def get_forgetting_stats(
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
end_user_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
@@ -254,26 +259,25 @@ async def get_forgetting_stats(
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
@@ -283,14 +287,14 @@ async def get_forgetting_stats(
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
@@ -324,7 +328,7 @@ async def get_forgetting_curve(
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
request.config_id = resolve_config_id((request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
|
||||
@@ -27,27 +27,27 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group (usually end_user_id in this context)
|
||||
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(group_id)
|
||||
count_stats = service.get_memory_count(end_user_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
@@ -57,37 +57,37 @@ def get_memory_count(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(group_id)
|
||||
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
@@ -101,37 +101,37 @@ def get_last_visual_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(group_id)
|
||||
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
@@ -145,38 +145,38 @@ def get_last_memory_listen(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(group_id)
|
||||
text_memory = service.get_latest_text_memory(end_user_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
@@ -190,16 +190,16 @@ def get_last_text_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
@@ -209,7 +209,7 @@ def get_memory_time_line(
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
@@ -221,7 +221,7 @@ def get_memory_time_line(
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -232,7 +232,7 @@ def get_memory_time_line(
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(group_id, query)
|
||||
timeline_data = service.get_time_line(end_user_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
@@ -246,7 +246,7 @@ def get_memory_time_line(
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
@@ -11,7 +13,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
@@ -24,6 +26,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -42,15 +46,15 @@ async def save_reflection_config(
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
data_config = DataConfigRepository.update_reflection_config(
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
@@ -63,17 +67,17 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(data_config)
|
||||
db.refresh(memory_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": data_config.config_id,
|
||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||
"iteration_period": data_config.iteration_period,
|
||||
"reflexion_range": data_config.reflexion_range,
|
||||
"baseline": data_config.baseline,
|
||||
"reflection_model_id": data_config.reflection_model_id,
|
||||
"memory_verify": data_config.memory_verify,
|
||||
"quality_assessment": data_config.quality_assessment}
|
||||
"config_id": memory_config.config_id,
|
||||
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||
"iteration_period": memory_config.iteration_period,
|
||||
"reflexion_range": memory_config.reflexion_range,
|
||||
"baseline": memory_config.baseline,
|
||||
"reflection_model_id": memory_config.reflection_model_id,
|
||||
"memory_verify": memory_config.memory_verify,
|
||||
"quality_assessment": memory_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -98,51 +102,71 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
# 跳过没有配置的应用
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
@@ -156,17 +180,20 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
@@ -191,17 +218,19 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
@@ -20,10 +21,13 @@ router = APIRouter(
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type:str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_result=short_term.get_short_databasets()
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
@@ -10,7 +15,6 @@ from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
@@ -30,10 +34,12 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -69,68 +75,9 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@@ -138,9 +85,8 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
@@ -160,39 +106,96 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
- 检查是否为默认配置,默认配置不允许删除
|
||||
- 检查是否有终端用户连接到该配置
|
||||
- 如果有连接且 force=False,返回警告
|
||||
- 如果 force=True,清除终端用户引用后删除配置
|
||||
|
||||
Query Parameters:
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.FORBIDDEN,
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
f"connected_count={result['connected_count']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.RESOURCE_IN_USE,
|
||||
msg=result["message"],
|
||||
data={
|
||||
"connected_count": result["connected_count"],
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
)
|
||||
return success(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -208,9 +211,9 @@ def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
@@ -232,12 +235,12 @@ def update_config_extracted(
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
@@ -256,7 +259,7 @@ def read_config_extracted(
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -278,16 +281,22 @@ def read_all_config(
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
svc.pilot_run_stream(payload, language=language),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
@@ -296,9 +305,8 @@ async def pilot_run(
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
# ==================== Search & Analytics ====================
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
@@ -420,15 +428,96 @@ async def get_hot_memory_tags_api(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
缓存策略:
|
||||
- 缓存键:workspace_id + limit
|
||||
- 过期时间:5分钟(300秒)
|
||||
- 缓存命中:~50ms
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
try:
|
||||
data = json.loads(cached_result)
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
cache_data = json.dumps(result, ensure_ascii=False)
|
||||
await aio_redis_set(cache_key, cache_data, expire=300)
|
||||
api_logger.info(f"Cached result for key: {cache_key}")
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
用于:
|
||||
- 手动刷新数据
|
||||
- 调试和测试
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
result = await aio_redis_delete(cache_key)
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -20,18 +20,18 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -39,7 +39,7 @@ def get_conversations(
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
group_id (UUID): The group identifier.
|
||||
end_user_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_conversations(
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
group_id
|
||||
end_user_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
@@ -63,7 +63,7 @@ def get_conversations(
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -100,7 +100,7 @@ def get_messages(
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -24,24 +26,83 @@ router = APIRouter(
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
return success(msg="获取模型提供商成功", data=providers)
|
||||
|
||||
@router.get("/strategy", response_model=ApiResponse)
|
||||
def get_model_strategies():
|
||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(
|
||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list_new(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -53,36 +114,127 @@ def get_model_list(
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
api_logger.info(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQueryNew(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
is_composite=is_composite,
|
||||
search=search
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/model_plaza", response_model=ApiResponse)
|
||||
def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""模型广场查询接口(按供应商分组)"""
|
||||
|
||||
query = model_schema.ModelBaseQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_official=is_official,
|
||||
is_deprecated=is_deprecated,
|
||||
search=search
|
||||
)
|
||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
return success(data=result, msg="模型广场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def get_model_base_by_id(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取基础模型详情"""
|
||||
|
||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza", response_model=ApiResponse)
|
||||
def create_model_base(
|
||||
data: model_schema.ModelBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建基础模型"""
|
||||
|
||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||
|
||||
|
||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def update_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
data: model_schema.ModelBaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
# 不允许更改type类型
|
||||
if data.type is not None or data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
|
||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def delete_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除基础模型"""
|
||||
|
||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||
return success(msg="基础模型删除成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||
def add_model_from_plaza(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从模型广场添加模型到模型列表"""
|
||||
|
||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
@@ -138,6 +290,73 @@ async def create_model(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建组合模型
|
||||
|
||||
- 绑定一个或多个现有的 API Key
|
||||
- 所有 API Key 必须来自非组合模型
|
||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||
"""
|
||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新组合模型"""
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||
def delete_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除组合模型"""
|
||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||
return success(msg="组合模型删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
@@ -149,6 +368,9 @@ def update_model(
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
@@ -214,6 +436,53 @@ def get_model_api_keys(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||
async def create_model_api_key_by_provider(
|
||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据供应商为所有匹配的模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 根据tenant_id和provider筛选model_config_id列表
|
||||
model_config_ids = api_key_data.model_config_ids
|
||||
if not model_config_ids:
|
||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
provider=api_key_data.provider
|
||||
)
|
||||
|
||||
if not model_config_ids:
|
||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 构造schema并调用service
|
||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
description=api_key_data.description,
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
@@ -228,11 +497,12 @@ async def create_model_api_key(
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
api_key_data.model_config_ids = [model_id]
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
@@ -334,5 +604,3 @@ async def validate_model_config(
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
1132
api/app/controllers/ontology_controller.py
Normal file
1132
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
611
api/app/controllers/ontology_secondary_routes.py
Normal file
611
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
class_data = classes_data[0]
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个本体类型"""
|
||||
api_logger.info(
|
||||
f"Get class requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取类型(会抛出ValueError如果不存在)
|
||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 类型不存在或无权限访问
|
||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
|
||||
scene_id: str,
|
||||
class_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取类型列表(支持模糊搜索和全量查询)
|
||||
|
||||
当提供 class_name 参数时,进行模糊搜索;
|
||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID(必填)
|
||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if class_name else "list"
|
||||
api_logger.info(
|
||||
f"Class {operation} requested by user {current_user.id}, "
|
||||
f"keyword={class_name}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
scene_uuid = UUID(scene_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取场景信息
|
||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
||||
|
||||
# 根据是否提供 class_name 决定查询方式
|
||||
if class_name and class_name.strip():
|
||||
# 模糊搜索类型
|
||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
||||
else:
|
||||
# 获取所有类型
|
||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for ontology_class in classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassListResponse(
|
||||
total=len(items),
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
items=items
|
||||
)
|
||||
|
||||
if class_name:
|
||||
api_logger.info(
|
||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -116,7 +120,8 @@ async def get_prompt_opt(
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
user_require=data.message,
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
@@ -135,3 +140,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -317,9 +317,12 @@ async def chat(
|
||||
appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
@@ -435,7 +438,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -472,7 +476,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -575,6 +580,7 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -582,7 +588,8 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_app_or_workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
@@ -21,9 +20,10 @@ from app.schemas import AppChatRequest, conversation_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
@@ -34,6 +34,7 @@ async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/app/chat
|
||||
|
||||
# @router.post("/chat")
|
||||
@@ -73,16 +74,17 @@ def _checkAppConfig(app: App):
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
@@ -98,8 +100,8 @@ async def chat(
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search=True
|
||||
memory=True
|
||||
web_search = True
|
||||
memory = True
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
@@ -146,16 +148,17 @@ async def chat(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -175,12 +178,13 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= agent_config,
|
||||
config=agent_config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -190,15 +194,15 @@ async def chat(
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -232,19 +236,19 @@ async def chat(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -268,11 +272,11 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
@@ -294,4 +298,3 @@ async def chat(
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -246,3 +246,73 @@ async def rebuild_knowledge_graph(
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
|
||||
yuque_token=yuque_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
|
||||
feishu_app_secret=feishu_app_secret,
|
||||
feishu_folder_token=feishu_folder_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
85
api/app/controllers/skill_controller.py
Normal file
85
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Skill Controller - 技能市场管理"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.create_skill(db, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skills, total = SkillService.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取技能详情"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||
return success(msg="技能删除成功")
|
||||
@@ -2,15 +2,23 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.user_schema import (
|
||||
ChangePasswordRequest,
|
||||
AdminChangePasswordRequest,
|
||||
SendEmailCodeRequest,
|
||||
VerifyEmailCodeRequest,
|
||||
VerifyPasswordRequest)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -92,7 +100,7 @@ def get_current_user_info(
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
@@ -120,6 +128,7 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
@@ -180,4 +189,54 @@ async def admin_change_password(
|
||||
return success(msg="密码修改成功")
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
|
||||
await user_service.verify_and_change_email(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
new_email=request.new_email,
|
||||
code=request.code
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
|
||||
@@ -8,11 +8,11 @@ from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
@@ -45,7 +45,6 @@ router = APIRouter(
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -55,18 +54,10 @@ async def get_memory_insight_report_api(
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
@@ -82,7 +73,7 @@ async def get_memory_insight_report_api(
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -91,7 +82,14 @@ async def get_user_summary_api(
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -103,7 +101,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -119,6 +117,7 @@ async def get_user_summary_api(
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -127,7 +126,14 @@ async def generate_cache_api(
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -135,27 +141,27 @@ async def generate_cache_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
end_user_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
|
||||
)
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
@@ -175,9 +181,9 @@ async def generate_cache_api(
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
@@ -185,7 +191,7 @@ async def generate_cache_api(
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
@@ -385,10 +391,13 @@ async def update_end_user_profile(
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -398,7 +407,7 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
|
||||
@@ -1,610 +0,0 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
162
api/app/core/agent/agent_middleware.py
Normal file
162
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skills: Optional[dict] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||
"""
|
||||
self.skills = skills or {}
|
||||
self.enabled = self.skills.get('enabled', False)
|
||||
self.all_skills = self.skills.get('all_skills', False)
|
||||
self.skill_ids = self.skills.get('skill_ids', [])
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
skill_ids_to_load = []
|
||||
|
||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||
if self.enabled and self.all_skills:
|
||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||
elif self.enabled and self.skill_ids:
|
||||
skill_ids_to_load = self.skill_ids
|
||||
|
||||
if skill_ids_to_load:
|
||||
for skill_id in skill_ids_to_load:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -7,27 +7,21 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -43,7 +37,9 @@ class LangChainAgent:
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -56,13 +52,36 @@ class LangChainAgent:
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
streaming: 是否启用流式输出
|
||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
@@ -86,11 +105,14 @@ class LangChainAgent:
|
||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||
self._underlying_llm.streaming = True
|
||||
|
||||
# 包装工具以跟踪连续调用次数
|
||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||
|
||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||
# 无论是否有工具,都使用 agent 统一处理
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools if self.tools else None,
|
||||
tools=wrapped_tools,
|
||||
system_prompt=self.system_prompt
|
||||
)
|
||||
|
||||
@@ -102,17 +124,91 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": self.max_iterations,
|
||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||
"""包装工具以跟踪连续调用次数
|
||||
|
||||
Args:
|
||||
tools: 原始工具列表
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: 包装后的工具列表
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
# 检查是否是连续调用同一个工具
|
||||
if self.last_tool_called == tool_name:
|
||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||
else:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||
f"返回提示信息"
|
||||
)
|
||||
return (
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
description=original_tool.description,
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -120,6 +216,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件内容列表(已处理)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
@@ -142,101 +239,46 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
# 构建用户消息(支持多模态)
|
||||
if files and len(files) > 0:
|
||||
content_parts = self._build_multimodal_content(user_content, files)
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
构建多模态消息内容
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
text: 文本内容
|
||||
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Returns:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -247,7 +289,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -281,33 +324,9 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
"准备调用 LangChain Agent",
|
||||
@@ -315,27 +334,85 @@ class LangChainAgent:
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
"has_tools": bool(self.tools),
|
||||
"message_count": len(messages)
|
||||
"has_files": bool(files),
|
||||
"message_count": len(messages),
|
||||
"max_iterations": self.max_iterations
|
||||
}
|
||||
)
|
||||
|
||||
# 统一使用 agent.invoke 调用
|
||||
result = await self.agent.ainvoke({"messages": messages})
|
||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||
try:
|
||||
result = await self.agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
)
|
||||
# 返回一个友好的错误提示
|
||||
return {
|
||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||
"model": self.model_name,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||
elif isinstance(msg.content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||
text_parts = []
|
||||
for item in msg.content:
|
||||
logger.debug(f"处理项: {item}")
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
logger.debug(f"提取字符串: {item[:100]}...")
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"合并后内容长度: {len(content)}")
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -343,7 +420,7 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,7 +447,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -403,33 +481,15 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
@@ -437,11 +497,12 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2"
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
@@ -450,20 +511,70 @@ class LangChainAgent:
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if hasattr(chunk, "content"):
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
@@ -474,12 +585,17 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -9,6 +9,25 @@ load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
# ========================================================================
|
||||
# Deployment Mode Configuration
|
||||
# ========================================================================
|
||||
# community: 社区版(开源,功能受限)
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -72,6 +91,10 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
@@ -107,6 +130,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -133,6 +157,11 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -164,16 +193,29 @@ class Settings:
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
# Periodic Task Schedule Configuration
|
||||
# workspace_reflection: 每隔多少秒执行一次
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))
|
||||
# forgetting_cycle: 每隔多少小时执行一次
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))
|
||||
# implicit_emotions_update: 每天几点执行(小时,0-23)
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
@@ -184,11 +226,36 @@ class Settings:
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
@@ -46,6 +46,7 @@ class BizCode(IntEnum):
|
||||
RESOURCE_ALREADY_EXISTS = 5002
|
||||
VERSION_ALREADY_EXISTS = 5003
|
||||
STATE_CONFLICT = 5004
|
||||
RESOURCE_IN_USE = 5005
|
||||
|
||||
# 应用发布(6xxx)
|
||||
PUBLISH_FAILED = 6001
|
||||
@@ -125,6 +126,7 @@ HTTP_MAPPING = {
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.RESOURCE_IN_USE: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
|
||||
82
api/app/core/language_utils.py
Normal file
82
api/app/core/language_utils.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""语言处理工具模块
|
||||
|
||||
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
|
||||
|
||||
Functions:
|
||||
validate_language: 校验语言参数,确保其为有效值
|
||||
get_language_from_header: 从请求头获取并校验语言参数
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 支持的语言列表
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
|
||||
# 默认回退语言
|
||||
DEFAULT_LANGUAGE = "zh"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码,可以是 None、"zh"、"en" 或其他值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> validate_language("zh")
|
||||
'zh'
|
||||
>>> validate_language("en")
|
||||
'en'
|
||||
>>> validate_language("EN") # 大小写不敏感
|
||||
'en'
|
||||
>>> validate_language(None) # None 回退到默认值
|
||||
'zh'
|
||||
>>> validate_language("fr") # 不支持的语言回退到默认值
|
||||
'zh'
|
||||
"""
|
||||
if language is None:
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
# 标准化:转小写并去除空白
|
||||
lang = str(language).lower().strip()
|
||||
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
|
||||
def get_language_from_header(language_type: Optional[str]) -> str:
|
||||
"""
|
||||
从请求头获取并校验语言参数。
|
||||
|
||||
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
|
||||
|
||||
Args:
|
||||
language_type: 从 X-Language-Type Header 获取的语言值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> get_language_from_header(None) # Header 未传递
|
||||
'zh'
|
||||
>>> get_language_from_header("en")
|
||||
'en'
|
||||
>>> get_language_from_header("invalid") # 无效值回退
|
||||
'zh'
|
||||
"""
|
||||
return validate_language(language_type)
|
||||
@@ -38,6 +38,56 @@ class SensitiveDataLoggingFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
class Neo4jSuccessNotificationFilter(logging.Filter):
|
||||
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
|
||||
|
||||
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
|
||||
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
|
||||
- 00000: 成功完成 (successful completion)
|
||||
- 00N00: 无数据 (no data)
|
||||
- 00NA0: 无数据,信息性通知 (no data, informational notification)
|
||||
|
||||
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
# 编译正则表达式以提高性能
|
||||
# 匹配所有"成功/信息性"的 GQL 状态码:
|
||||
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
|
||||
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
|
||||
|
||||
# 匹配 status_description 中的成功完成或信息性通知消息
|
||||
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""
|
||||
过滤 Neo4j 成功通知
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
|
||||
Returns:
|
||||
True表示允许记录,False表示拒绝(过滤掉)
|
||||
"""
|
||||
# 只处理 INFO 和 WARNING 级别的日志
|
||||
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
|
||||
# 对 severity='WARNING' 的通知使用 WARNING 级别
|
||||
if record.levelno not in (logging.INFO, logging.WARNING):
|
||||
return True
|
||||
|
||||
# 检查是否是 Neo4j 的成功通知
|
||||
message = str(record.msg)
|
||||
|
||||
# 使用正则表达式进行更严格的匹配
|
||||
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
|
||||
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
|
||||
return False # 过滤掉这条日志
|
||||
|
||||
# 保留其他所有日志(包括真正的警告和错误)
|
||||
return True
|
||||
|
||||
|
||||
class LoggingConfig:
|
||||
"""全局日志配置类"""
|
||||
|
||||
@@ -65,6 +115,22 @@ class LoggingConfig:
|
||||
# 清除现有处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
|
||||
neo4j_filter = Neo4jSuccessNotificationFilter()
|
||||
|
||||
# 抑制 Neo4j 通知日志
|
||||
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler,
|
||||
# 导致日志绕过根 logger 的 filter 直接输出。
|
||||
# 多管齐下确保过滤生效:
|
||||
# 1. 设置 neo4j.notifications 级别为 WARNING(过滤 INFO 级别的 00NA0 通知)
|
||||
# 2. 在所有 neo4j logger 上添加 filter(过滤 WARNING 级别的成功通知)
|
||||
# 3. 在根 handler 上也添加 filter(兜底)
|
||||
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
|
||||
neo4j_notifications_logger.setLevel(logging.WARNING)
|
||||
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
|
||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
||||
neo4j_logger.addFilter(neo4j_filter)
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(
|
||||
fmt=settings.LOG_FORMAT,
|
||||
@@ -80,6 +146,7 @@ class LoggingConfig:
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
console_handler.addFilter(sensitive_filter)
|
||||
console_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(带轮转)
|
||||
@@ -93,6 +160,7 @@ class LoggingConfig:
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
file_handler.addFilter(sensitive_filter)
|
||||
file_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
cls._initialized = True
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
kb_config = await rag_config(state)
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
group_id=state.get('group_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取group_id
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start=time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
group_id = state.get("group_id", '')
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
group_id = state.get("group_id", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=group_id,
|
||||
user_id=end_user_id,
|
||||
query=data,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
apply_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
group_id=state.get("group_id", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
|
||||
@@ -1,23 +1,25 @@
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, memory_config, and language
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
language = state.get('language', 'zh') # 默认中文
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
@@ -28,14 +30,13 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
language=language,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
|
||||
@@ -79,7 +79,7 @@ async def make_read_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
group_id = '88a459f5_text09' # 组ID
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
@@ -95,9 +95,9 @@ async def main():
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(group_id: str):
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_group_id = group_id_param or group_id
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
group_id=actual_group_id,
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含group_id, limit, include等
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
@@ -211,7 +212,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -224,7 +225,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
@@ -238,7 +239,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id or search_params.get("group_id"),
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
@@ -291,7 +292,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
@@ -301,7 +302,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
@@ -311,7 +312,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -1,33 +1,37 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
@@ -38,43 +42,63 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
group_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message item in conversation."""
|
||||
|
||||
role: str = Field(..., description="角色:user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
|
||||
"""Response model for aggregate judgment containing judgment result and output."""
|
||||
|
||||
is_same_event: bool = Field(
|
||||
...,
|
||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
||||
)
|
||||
output: Union[List[MessageItem], bool] = Field(
|
||||
...,
|
||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
||||
)
|
||||
|
||||
|
||||
# 为了保持向后兼容,保留旧的类名作为别名
|
||||
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
|
||||
@@ -91,7 +91,7 @@ class SearchService:
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
@@ -105,7 +105,7 @@ class SearchService:
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
end_user_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
@@ -130,7 +130,7 @@ class SearchService:
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
@@ -186,7 +186,7 @@ class SearchService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
|
||||
"""
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
return [dialog_data]
|
||||
|
||||
@@ -1,24 +1,24 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data:str
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
@@ -28,7 +28,7 @@ class ReadState(TypedDict):
|
||||
messages: 消息列表,支持自动追加
|
||||
loop_count: 遍历次数
|
||||
search_switch: 搜索类型开关
|
||||
group_id: 组标识
|
||||
end_user_id: 组标识
|
||||
config_id: 配置ID,用于过滤结果
|
||||
data: 从content_input_node传递的内容数据
|
||||
spit_data: 从Split_The_Problem传递的分解结果
|
||||
@@ -39,7 +39,7 @@ class ReadState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||
loop_count: int
|
||||
search_switch: str
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# 角色
|
||||
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
# 用户问题
|
||||
{{query}}
|
||||
# 回答指南
|
||||
## 1. 仔细阅读检索信息
|
||||
- 答案可能直接或间接地出现在检索信息中
|
||||
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
|
||||
- 第三人称描述的偏好、行为通常指用户本人
|
||||
|
||||
## 2. 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
## 3. 回答要求
|
||||
- 像人类对话一样自然流畅
|
||||
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||
- 不要解释推理过程或引用信息来源
|
||||
- 保持友好、乐于助人的语气
|
||||
- 使用与问题相同的语言回答
|
||||
# 关键示例
|
||||
**示例1 - 直接匹配:**
|
||||
- 检索信息:"小曼会使用Python..."
|
||||
- 问题:"我叫什么"
|
||||
- ✓ 正确:"你叫小曼"
|
||||
- ✗ 错误:"你没有告诉我你的名字"
|
||||
**示例2 - 间接匹配:**
|
||||
- 检索信息:"用户很喜欢吃星巴克的甜品"
|
||||
- 问题:"我喜欢什么"
|
||||
- ✓ 正确:"你很喜欢吃星巴克的甜品"
|
||||
- ✗ 错误:"信息不足"
|
||||
**示例3 - 信息不匹配(推荐做法):**
|
||||
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
|
||||
- 问题:"我吃过哪家面包"
|
||||
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
|
||||
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
|
||||
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
|
||||
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
|
||||
# 重要提醒
|
||||
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
|
||||
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
|
||||
- 用对话式语言表达"不知道",而非机械模板
|
||||
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆
|
||||
@@ -0,0 +1,43 @@
|
||||
{# 角色定义 #}
|
||||
你是专业的问题解答专家+引导学者
|
||||
|
||||
{# 输入数据展示 #}
|
||||
{% if data %}
|
||||
## 输入数据
|
||||
上下文信息:
|
||||
{% for item in data.history %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
检索到的所有信息:
|
||||
{% for item in data.retrieve_info %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## User Query
|
||||
{{ query }}
|
||||
|
||||
{# 问题回答标准 #}
|
||||
## 问题回答核心标准
|
||||
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
|
||||
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
|
||||
{# 重要提醒 #}
|
||||
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
|
||||
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,36 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,210 +41,633 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"group_id": session.get('group_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
"""
|
||||
读取一条会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取所有会话数据
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
"""
|
||||
删除单条会话
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
"""
|
||||
删除所有会话
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
group_id = data.get('group_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -29,36 +30,34 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
logger.info(f"end_user_id ID: {end_user_id}")
|
||||
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
@@ -83,9 +82,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
@@ -97,12 +94,39 @@ async def write(
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
# Fetch ontology types if scene_id is configured
|
||||
ontology_types = None
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=pipeline_config,
|
||||
embedding_id=embedding_model_id,
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
|
||||
# Run the complete extraction pipeline
|
||||
@@ -127,23 +151,48 @@ async def write(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
@@ -151,7 +200,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -177,4 +226,4 @@ async def write(
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_default_docs_path() -> str:
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
from pathlib import Path
|
||||
project_root = str(Path(__file__).resolve().parents[2])
|
||||
return os.path.join(project_root, "src", "analytics", "API接口.md")
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
end_user_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
@@ -37,18 +37,22 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id:
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
|
||||
# Use the config_id to get the proper LLM client
|
||||
# Use the config_id to get the proper LLM client with workspace fallback
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
@@ -87,7 +91,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
@@ -99,9 +103,9 @@ async def get_raw_tags_from_db(
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
@@ -119,7 +123,7 @@ async def get_raw_tags_from_db(
|
||||
else:
|
||||
query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC "
|
||||
"LIMIT $limit"
|
||||
@@ -128,44 +132,44 @@ async def get_raw_tags_from_db(
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
id=end_user_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
# 验证end_user_id必须提供且不为空
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -108,7 +108,6 @@ class DimensionAnalyzer:
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
@@ -220,7 +219,7 @@ class DimensionAnalyzer:
|
||||
"""Create an empty dimension portrait when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty DimensionPortrait
|
||||
@@ -228,7 +227,6 @@ class DimensionAnalyzer:
|
||||
current_time = datetime.now()
|
||||
|
||||
return DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=self._create_default_dimension_score("creativity"),
|
||||
aesthetic=self._create_default_dimension_score("aesthetic"),
|
||||
technology=self._create_default_dimension_score("technology"),
|
||||
|
||||
@@ -7,7 +7,7 @@ providing percentage distribution that totals 100%.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
@@ -133,7 +133,6 @@ class InterestAnalyzer:
|
||||
|
||||
# Create interest area distribution
|
||||
distribution = InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=interest_categories["tech"],
|
||||
lifestyle=interest_categories["lifestyle"],
|
||||
music=interest_categories["music"],
|
||||
@@ -251,7 +250,7 @@ class InterestAnalyzer:
|
||||
"""Create an empty interest distribution when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty InterestAreaDistribution with equal percentages
|
||||
@@ -259,15 +258,15 @@ class InterestAnalyzer:
|
||||
current_time = datetime.now()
|
||||
equal_percentage = 25.0 # 100% / 4 categories
|
||||
|
||||
default_category = lambda name: InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
def default_category(name: str) -> InterestCategory:
|
||||
return InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
return InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=default_category("tech"),
|
||||
lifestyle=default_category("lifestyle"),
|
||||
music=default_category("music"),
|
||||
|
||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||
end_user_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
|
||||
@@ -16,6 +16,7 @@ Summary {{ loop.index }}:
|
||||
3. DO NOT use long phrases - use short nouns or noun phrases
|
||||
4. Only include preferences with confidence_score >= 0.3
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
6. **CRITICAL: supporting_evidence must be DIRECT QUOTES or paraphrases from the user's actual statements. DO NOT reference summary numbers (e.g., "Summary 1", "摘要1"). DO NOT describe what the summary contains. Extract the actual user behavior or statement as evidence.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
@@ -38,6 +39,16 @@ Summary {{ loop.index }}:
|
||||
]
|
||||
}
|
||||
|
||||
## BAD supporting_evidence examples (DO NOT do this):
|
||||
- "Summary 1:西湖为核心景区" ❌
|
||||
- "摘要2中提到喜欢咖啡" ❌
|
||||
- "Based on Summary 3" ❌
|
||||
|
||||
## GOOD supporting_evidence examples:
|
||||
- "去过西湖断桥、苏堤" ✓
|
||||
- "每天早上喝咖啡" ✓
|
||||
- "mentioned visiting the lake twice" ✓
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"preferences": [
|
||||
|
||||
@@ -2,13 +2,16 @@ import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
||||
except Exception:
|
||||
# Fallback: derive project root from this file location
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
|
||||
# 需要向上 5 级到达 api/ 目录
|
||||
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
|
||||
|
||||
|
||||
def _get_latest_prompt_log_path() -> str | None:
|
||||
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
triplet_relations_count = 0
|
||||
temporal_count = 0
|
||||
|
||||
# Patterns
|
||||
# 正则表达式模式 - 匹配当前日志格式
|
||||
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
|
||||
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
|
||||
pat_triplet_done = re.compile(
|
||||
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
|
||||
pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
|
||||
pat_triplet_completed = re.compile(
|
||||
r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
|
||||
)
|
||||
pat_temporal_done = re.compile(
|
||||
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
|
||||
pat_temporal_completed = re.compile(
|
||||
r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
|
||||
)
|
||||
|
||||
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
for line in f:
|
||||
# Chunk prompts count (each chunk triggers one statement-extraction prompt render)
|
||||
# 文本块数量(每个块触发一次陈述提取提示)
|
||||
if pat_chunk_render.search(line):
|
||||
chunk_count += 1
|
||||
continue
|
||||
|
||||
m1 = pat_triplet_start.search(line)
|
||||
if m1:
|
||||
# 陈述数量(每个 Triplet Started 代表一个陈述被处理)
|
||||
if pat_triplet_started.search(line):
|
||||
statements_count += 1
|
||||
continue
|
||||
|
||||
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
|
||||
m_triplet = pat_triplet_completed.search(line)
|
||||
if m_triplet:
|
||||
try:
|
||||
statements_count += int(m1.group(1))
|
||||
triplet_relations_count += int(m_triplet.group(1))
|
||||
triplet_entities_count += int(m_triplet.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m2 = pat_triplet_done.search(line)
|
||||
if m2:
|
||||
# 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
|
||||
m_temporal = pat_temporal_completed.search(line)
|
||||
if m_temporal:
|
||||
try:
|
||||
triplet_relations_count += int(m2.group(1))
|
||||
triplet_entities_count += int(m2.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m3 = pat_temporal_done.search(line)
|
||||
if m3:
|
||||
try:
|
||||
temporal_count += int(m3.group(1))
|
||||
temporal_count += int(m_temporal.group(1))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
|
||||
|
||||
def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"""Get aggregated stats from all prompt logs in logs/.
|
||||
"""Get stats from the latest prompt log file only.
|
||||
|
||||
Returns (stats_dict, message).
|
||||
"""
|
||||
all_logs = _get_all_prompt_logs()
|
||||
# Fallback to recursive search if none found in logs/
|
||||
if not all_logs:
|
||||
# 获取最新的日志文件
|
||||
latest_log = _get_latest_prompt_log_path()
|
||||
|
||||
# 如果没有找到,尝试递归搜索
|
||||
if not latest_log:
|
||||
all_logs = _get_any_logs_recursive()
|
||||
if not all_logs:
|
||||
if all_logs:
|
||||
latest_log = all_logs[-1] # 取最新的
|
||||
|
||||
if not latest_log:
|
||||
return (
|
||||
{
|
||||
"chunk_count": 0,
|
||||
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"未找到日志文件,请确认已运行过提取流程。",
|
||||
)
|
||||
|
||||
agg = {
|
||||
"chunk_count": 0,
|
||||
"statements_count": 0,
|
||||
"triplet_entities_count": 0,
|
||||
"triplet_relations_count": 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
for path in all_logs:
|
||||
s = parse_stats_from_log(path)
|
||||
agg["chunk_count"] += s.get("chunk_count", 0)
|
||||
agg["statements_count"] += s.get("statements_count", 0)
|
||||
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
|
||||
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
|
||||
agg["temporal_count"] += s.get("temporal_count", 0)
|
||||
|
||||
# Attach a summary of files combined
|
||||
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
|
||||
return agg, "成功汇总 logs 目录中所有提示日志。"
|
||||
# 只解析最新的日志文件
|
||||
stats = parse_stats_from_log(latest_log)
|
||||
|
||||
# 添加日志文件路径信息
|
||||
stats["log_path"] = f"最新:{latest_log}"
|
||||
|
||||
return stats, "成功读取最近一次记忆活动统计。"
|
||||
|
||||
|
||||
def _format_summary(stats: dict) -> str:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
||||
@@ -1,30 +0,0 @@
|
||||
⏬数据集下载地址:
|
||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
||||
|
||||
全流程基准测试运行:
|
||||
locomo:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
||||
LongMemEval:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
||||
memsciqa:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
||||
|
||||
单独检索评估运行命令:
|
||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
||||
需要先在项目中修改需要检测评估的group_id。
|
||||
|
||||
参数及解释:
|
||||
● --dataset longmemeval - 指定数据集
|
||||
● --sample-size 10 - 评估10个样本
|
||||
● --start-index 0 - 从第0个样本开始
|
||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
||||
● --search-limit 8 - 检索限制8条
|
||||
● --context-char-budget 4000 - 上下文字符预算4000
|
||||
● --search-type hybrid - 使用混合检索
|
||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
||||
● --reset-group - 运行前清空组数据
|
||||
@@ -1,100 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = [t for t in text.split() if t]
|
||||
return tokens
|
||||
|
||||
|
||||
def exact_match(pred: str, ref: str) -> float:
|
||||
return float(_normalize(pred) == _normalize(ref))
|
||||
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
p = set(_normalize(pred))
|
||||
r = set(_normalize(ref))
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
# Clipped count
|
||||
r_counts: Dict[str, int] = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
clipped = 0
|
||||
p_counts: Dict[str, int] = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
# Brevity penalty
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
return bp * precision
|
||||
|
||||
|
||||
def percentile(values: List[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
vals = sorted(values)
|
||||
k = (len(vals) - 1) * p
|
||||
f = math.floor(k)
|
||||
c = math.ceil(k)
|
||||
if f == c:
|
||||
return vals[int(k)]
|
||||
return vals[f] + (k - f) * (vals[c] - vals[f])
|
||||
|
||||
|
||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
|
||||
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
|
||||
if not latencies_ms:
|
||||
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
|
||||
p25 = percentile(latencies_ms, 0.25)
|
||||
p50 = percentile(latencies_ms, 0.50)
|
||||
p75 = percentile(latencies_ms, 0.75)
|
||||
p95 = percentile(latencies_ms, 0.95)
|
||||
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
|
||||
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
|
||||
|
||||
|
||||
def avg_context_tokens(contexts: List[str]) -> float:
|
||||
if not contexts:
|
||||
return 0.0
|
||||
return sum(len(_normalize(c)) for c in contexts) / len(contexts)
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
Dialogue search queries for evaluation purposes.
|
||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# Entity search queries
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name = $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name CONTAINS $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
# Chunk search queries
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.content CONTAINS $content
|
||||
RETURN c
|
||||
"""
|
||||
|
||||
# Dialogue search queries
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_id = $dialog_id
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
SEARCH_DIALOGUES_BY_CONTENT = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.content CONTAINS $q
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
DIALOGUE_EMBEDDING_SEARCH = """
|
||||
WITH $embedding AS q
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
||||
WITH d, q, d.dialog_embedding AS v
|
||||
WITH d,
|
||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
||||
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
|
||||
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
|
||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
||||
WHERE score > $threshold
|
||||
RETURN d.id AS dialog_id,
|
||||
d.group_id AS group_id,
|
||||
d.content AS content,
|
||||
d.created_at AS created_at,
|
||||
d.expired_at AS expired_at,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -1,341 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
group_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
) -> bool:
|
||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
||||
|
||||
# Initialize llm client with graceful fallback
|
||||
llm_client = None
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
|
||||
# Step A: Build DialogData list from contexts with robust parsing
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
for idx, ctx in enumerate(contexts):
|
||||
messages: List[ConversationMessage] = []
|
||||
|
||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
||||
|
||||
if matches:
|
||||
for m in matches:
|
||||
raw_role = m.group(1).strip()
|
||||
content = m.group(2).strip()
|
||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
||||
else:
|
||||
# Fallback: line-by-line parsing
|
||||
for raw in ctx.split("\n"):
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
||||
else:
|
||||
# Final fallback: treat as user message
|
||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
||||
|
||||
context_model = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
group_id=group_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
# Generate chunks
|
||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("No dialogs to process for ingestion.")
|
||||
return False
|
||||
|
||||
# Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
||||
out_path = save_chunk_output_path or default_path
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to save chunking results: {e}")
|
||||
|
||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
||||
if not llm_available:
|
||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
||||
return False
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化并运行 ExtractionOrchestrator
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
||||
# 保存原始的 _assign_extracted_data 方法
|
||||
original_assign = orchestrator._assign_extracted_data
|
||||
|
||||
def clean_temporal_value(value):
|
||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
||||
return None
|
||||
return value
|
||||
|
||||
async def patched_assign_extracted_data(*args, **kwargs):
|
||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
||||
result = await original_assign(*args, **kwargs)
|
||||
|
||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
||||
for dialog in result:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
# 清理 valid_at 和 invalid_at
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
return result
|
||||
|
||||
# 替换方法
|
||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
||||
|
||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
||||
original_create = orchestrator._create_nodes_and_edges
|
||||
|
||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
||||
# 最后一次清理,确保万无一失
|
||||
for dialog in dialog_data_list_arg:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
|
||||
return await original_create(dialog_data_list_arg)
|
||||
|
||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run 返回 7 个元素的元组
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
|
||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
||||
|
||||
# Step G: 生成记忆摘要
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step H: Save to Neo4j
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
entity_edges=entity_entity_edges,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
# Save memory summaries separately
|
||||
if summaries:
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
if success:
|
||||
print("Successfully saved extracted data to Neo4j!")
|
||||
else:
|
||||
print("Failed to save data to Neo4j")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Failed to save data to Neo4j: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_context_processing(args):
|
||||
"""Handle context-based processing from command line arguments."""
|
||||
contexts = []
|
||||
|
||||
if args.contexts:
|
||||
contexts.extend(args.contexts)
|
||||
|
||||
if args.context_file:
|
||||
try:
|
||||
with open(args.context_file, 'r', encoding='utf-8') as f:
|
||||
contexts.extend(line.strip() for line in f if line.strip())
|
||||
except Exception as e:
|
||||
print(f"Error reading context file: {e}")
|
||||
return False
|
||||
|
||||
if not contexts:
|
||||
print("No contexts provided for processing.")
|
||||
return False
|
||||
|
||||
return await main_from_contexts(contexts, args.context_group_id)
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
||||
print("=== Running pipeline from provided contexts ===")
|
||||
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=contexts,
|
||||
group_id=group_id,
|
||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
||||
embedding_name=SELECTED_EMBEDDING_ID,
|
||||
save_chunk_output=True
|
||||
)
|
||||
|
||||
if success:
|
||||
print("Successfully processed and saved contexts to Neo4j!")
|
||||
else:
|
||||
print("Failed to process contexts.")
|
||||
|
||||
return success
|
||||
@@ -1,575 +0,0 @@
|
||||
"""
|
||||
LoCoMo Benchmark Script
|
||||
|
||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
||||
in a clean, maintainable way.
|
||||
|
||||
Usage:
|
||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
retrieve_relevant_information,
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
sample_size: int = 20,
|
||||
group_id: Optional[str] = None,
|
||||
search_type: str = "hybrid",
|
||||
search_limit: int = 12,
|
||||
context_char_budget: int = 8000,
|
||||
reset_group: bool = False,
|
||||
skip_ingest: bool = False,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run LoCoMo benchmark evaluation.
|
||||
|
||||
This function orchestrates the complete evaluation pipeline:
|
||||
1. Load LoCoMo dataset (only QA pairs from first conversation)
|
||||
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
|
||||
3. For each question:
|
||||
- Retrieve relevant information
|
||||
- Generate answer using LLM
|
||||
- Calculate metrics
|
||||
4. Aggregate results and save to file
|
||||
|
||||
Note: By default, only the first conversation is ingested into the database,
|
||||
and only QA pairs from that conversation are evaluated. This ensures that
|
||||
all questions have corresponding memory in the database for retrieval.
|
||||
|
||||
Args:
|
||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
||||
group_id: Database group ID for retrieval (uses default if None)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max documents to retrieve per query
|
||||
context_char_budget: Max characters for context
|
||||
reset_group: Whether to clear and re-ingest data (not implemented)
|
||||
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
|
||||
output_dir: Directory to save results (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation results including metrics, timing, and samples
|
||||
"""
|
||||
# Use default group_id if not provided
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
# Determine data path
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
# Fallback to current directory
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("🚀 Starting LoCoMo Benchmark Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Configuration:")
|
||||
print(f" Sample size: {sample_size}")
|
||||
print(f" Group ID: {group_id}")
|
||||
print(f" Search type: {search_type}")
|
||||
print(f" Search limit: {search_limit}")
|
||||
print(f" Context budget: {context_char_budget} chars")
|
||||
print(f" Data path: {data_path}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Step 1: Load LoCoMo data
|
||||
print("📂 Loading LoCoMo dataset...")
|
||||
try:
|
||||
# Only load QA pairs from the first conversation (index 0)
|
||||
# since we only ingest the first conversation into the database
|
||||
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
|
||||
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load data: {e}")
|
||||
return {
|
||||
"error": f"Data loading failed: {e}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Step 2: Extract conversations and ingest if needed
|
||||
if skip_ingest:
|
||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
||||
print(f" Group ID: {group_id}\n")
|
||||
else:
|
||||
print("💾 Checking database ingestion...")
|
||||
try:
|
||||
conversations = extract_conversations(data_path, max_dialogues=1)
|
||||
print(f"📝 Extracted {len(conversations)} conversations")
|
||||
|
||||
# Always ingest for now (ingestion check not implemented)
|
||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
||||
success = await ingest_conversations_if_needed(
|
||||
conversations=conversations,
|
||||
group_id=group_id,
|
||||
reset=reset_group
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ Ingestion completed successfully\n")
|
||||
else:
|
||||
print("⚠️ Ingestion may have failed, continuing anyway\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ingestion failed: {e}")
|
||||
print("⚠️ Continuing with evaluation (database may be empty)\n")
|
||||
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
print("✅ Clients initialized\n")
|
||||
|
||||
# Step 4: Process questions
|
||||
print(f"🔍 Processing {len(qa_items)} questions...")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Tracking variables
|
||||
latencies_search: List[float] = []
|
||||
latencies_llm: List[float] = []
|
||||
context_counts: List[int] = []
|
||||
context_chars: List[int] = []
|
||||
context_tokens: List[int] = []
|
||||
|
||||
# Metric lists
|
||||
f1_scores: List[float] = []
|
||||
bleu1_scores: List[float] = []
|
||||
jaccard_scores: List[float] = []
|
||||
locomo_f1_scores: List[float] = []
|
||||
|
||||
# Per-category tracking
|
||||
category_counts: Dict[str, int] = {}
|
||||
category_f1: Dict[str, List[float]] = {}
|
||||
category_bleu1: Dict[str, List[float]] = {}
|
||||
category_jaccard: Dict[str, List[float]] = {}
|
||||
category_locomo_f1: Dict[str, List[float]] = {}
|
||||
|
||||
# Detailed samples
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
# Fixed anchor date for temporal resolution
|
||||
anchor_date = datetime(2023, 5, 8)
|
||||
|
||||
try:
|
||||
for idx, item in enumerate(qa_items, 1):
|
||||
question = item.get("question", "")
|
||||
ground_truth = item.get("answer", "")
|
||||
category = get_category_name(item)
|
||||
|
||||
# Ensure ground truth is a string
|
||||
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
print(f"[{idx}/{len(qa_items)}] Category: {category}")
|
||||
print(f"❓ Question: {question}")
|
||||
print(f"✅ Ground Truth: {ground_truth_str}")
|
||||
|
||||
# Step 4a: Retrieve relevant information
|
||||
t_search_start = time.time()
|
||||
try:
|
||||
retrieved_info = await retrieve_relevant_information(
|
||||
question=question,
|
||||
group_id=group_id,
|
||||
search_type=search_type,
|
||||
search_limit=search_limit,
|
||||
connector=connector,
|
||||
embedder=embedder
|
||||
)
|
||||
t_search_end = time.time()
|
||||
search_latency = (t_search_end - t_search_start) * 1000
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Retrieval failed: {e}")
|
||||
retrieved_info = []
|
||||
search_latency = 0.0
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
# Step 4b: Select and format context
|
||||
context_text = select_and_format_information(
|
||||
retrieved_info=retrieved_info,
|
||||
question=question,
|
||||
max_chars=context_char_budget
|
||||
)
|
||||
|
||||
# Resolve temporal references
|
||||
context_text = resolve_temporal_references(context_text, anchor_date)
|
||||
|
||||
# Add reference date to context
|
||||
if context_text:
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
|
||||
else:
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# Track context statistics
|
||||
context_counts.append(len(retrieved_info))
|
||||
context_chars.append(len(context_text))
|
||||
context_tokens.append(len(context_text.split()))
|
||||
|
||||
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
|
||||
|
||||
# Step 4c: Generate answer with LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Question: {question}\n\nContext:\n{context_text}"
|
||||
}
|
||||
]
|
||||
|
||||
t_llm_start = time.time()
|
||||
try:
|
||||
response = await llm_client.chat(messages=messages)
|
||||
t_llm_end = time.time()
|
||||
llm_latency = (t_llm_end - t_llm_start) * 1000
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Extract prediction from response
|
||||
if hasattr(response, 'content'):
|
||||
prediction = response.content.strip()
|
||||
elif isinstance(response, dict):
|
||||
prediction = response["choices"][0]["message"]["content"].strip()
|
||||
else:
|
||||
prediction = "Unknown"
|
||||
|
||||
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ LLM failed: {e}")
|
||||
prediction = "Unknown"
|
||||
llm_latency = 0.0
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Step 4d: Calculate metrics
|
||||
f1_val = f1_score(prediction, ground_truth_str)
|
||||
bleu1_val = bleu1(prediction, ground_truth_str)
|
||||
jaccard_val = jaccard(prediction, ground_truth_str)
|
||||
|
||||
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
|
||||
if item.get("category") == 1:
|
||||
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
|
||||
else:
|
||||
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
|
||||
|
||||
# Accumulate metrics
|
||||
f1_scores.append(f1_val)
|
||||
bleu1_scores.append(bleu1_val)
|
||||
jaccard_scores.append(jaccard_val)
|
||||
locomo_f1_scores.append(locomo_f1_val)
|
||||
|
||||
# Track by category
|
||||
category_counts[category] = category_counts.get(category, 0) + 1
|
||||
category_f1.setdefault(category, []).append(f1_val)
|
||||
category_bleu1.setdefault(category, []).append(bleu1_val)
|
||||
category_jaccard.setdefault(category, []).append(jaccard_val)
|
||||
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
|
||||
|
||||
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
|
||||
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
|
||||
print()
|
||||
|
||||
# Save sample details
|
||||
samples.append({
|
||||
"question": question,
|
||||
"ground_truth": ground_truth_str,
|
||||
"prediction": prediction,
|
||||
"category": category,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"bleu1": bleu1_val,
|
||||
"jaccard": jaccard_val,
|
||||
"locomo_f1": locomo_f1_val
|
||||
},
|
||||
"retrieval": {
|
||||
"num_docs": len(retrieved_info),
|
||||
"context_length": len(context_text)
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_latency,
|
||||
"llm_ms": llm_latency
|
||||
}
|
||||
})
|
||||
|
||||
finally:
|
||||
# Close connector
|
||||
await connector.close()
|
||||
|
||||
# Step 5: Aggregate results
|
||||
print(f"\n{'='*60}")
|
||||
print("📊 Aggregating Results")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Overall metrics
|
||||
overall_metrics = {
|
||||
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
|
||||
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
|
||||
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
|
||||
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
by_category: Dict[str, Dict[str, Any]] = {}
|
||||
for cat in category_counts:
|
||||
f1_list = category_f1.get(cat, [])
|
||||
b1_list = category_bleu1.get(cat, [])
|
||||
j_list = category_jaccard.get(cat, [])
|
||||
lf_list = category_locomo_f1.get(cat, [])
|
||||
|
||||
by_category[cat] = {
|
||||
"count": category_counts[cat],
|
||||
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
|
||||
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
|
||||
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
|
||||
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
|
||||
}
|
||||
|
||||
# Latency statistics
|
||||
latency = {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm)
|
||||
}
|
||||
|
||||
# Context statistics
|
||||
context_stats = {
|
||||
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
|
||||
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
|
||||
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
|
||||
}
|
||||
|
||||
# Build result dictionary
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"sample_size": len(qa_items),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_type": search_type,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"overall_metrics": overall_metrics,
|
||||
"by_category": by_category,
|
||||
"latency": latency,
|
||||
"context_stats": context_stats,
|
||||
"samples": samples
|
||||
}
|
||||
|
||||
# Step 6: Save results
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"results"
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Generate timestamped filename
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ Results saved to: {output_path}\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to save results: {e}")
|
||||
print("📊 Printing results to console instead:\n")
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Parse command-line arguments and run benchmark.
|
||||
|
||||
This function provides a CLI interface for running LoCoMo benchmarks
|
||||
with configurable parameters.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run LoCoMo benchmark evaluation",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample_size",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of QA pairs to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Database group ID for retrieval (uses default if not specified)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_type",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
help="Search strategy to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_limit",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Maximum number of documents to retrieve per query"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_char_budget",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Maximum characters for context"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reset_group",
|
||||
action="store_true",
|
||||
help="Clear and re-ingest data (not implemented)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ingest",
|
||||
action="store_true",
|
||||
help="Skip data ingestion and use existing data in Neo4j"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save results (uses default if not specified)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Run benchmark
|
||||
result = asyncio.run(run_locomo_benchmark(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_type=args.search_type,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
reset_group=args.reset_group,
|
||||
skip_ingest=args.skip_ingest,
|
||||
output_dir=args.output_dir
|
||||
))
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
|
||||
# Check if there was an error
|
||||
if 'error' in result:
|
||||
print("❌ Benchmark Failed!")
|
||||
print(f"{'='*60}")
|
||||
print(f"Error: {result['error']}")
|
||||
return
|
||||
|
||||
print("🎉 Benchmark Complete!")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Final Results:")
|
||||
print(f" Sample size: {result.get('sample_size', 0)}")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
|
||||
|
||||
if result.get('context_stats'):
|
||||
print("\n📈 Context Statistics:")
|
||||
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
|
||||
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
|
||||
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
|
||||
|
||||
if result.get('latency'):
|
||||
print("\n⏱️ Latency Statistics:")
|
||||
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['search']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['search']['p95']:.1f}ms")
|
||||
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['llm']['p95']:.1f}ms")
|
||||
|
||||
if result.get('by_category'):
|
||||
print("\n📂 Results by Category:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" Count: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['jaccard']:.3f}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
LoCoMo-specific metric calculations.
|
||||
|
||||
This module provides clean, simplified implementations of metrics used for
|
||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
Normalize text for LoCoMo evaluation.
|
||||
|
||||
Normalization steps:
|
||||
- Convert to lowercase
|
||||
- Remove commas
|
||||
- Remove stop words (a, an, the, and)
|
||||
- Remove punctuation
|
||||
- Normalize whitespace
|
||||
|
||||
Args:
|
||||
text: Input text to normalize
|
||||
|
||||
Returns:
|
||||
Normalized text string with consistent formatting
|
||||
|
||||
Examples:
|
||||
>>> normalize_text("The cat, and the dog")
|
||||
'cat dog'
|
||||
>>> normalize_text("Hello, World!")
|
||||
'hello world'
|
||||
"""
|
||||
# Ensure input is a string
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
# Convert to lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Remove commas
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
|
||||
# Remove stop words
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
|
||||
# Remove punctuation (keep only word characters and whitespace)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
|
||||
# Normalize whitespace (collapse multiple spaces to single space)
|
||||
text = " ".join(text.split())
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for single-answer questions.
|
||||
|
||||
Uses token-level precision and recall based on normalized text.
|
||||
Treats tokens as sets (no duplicate counting).
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer
|
||||
ground_truth: Correct answer
|
||||
|
||||
Returns:
|
||||
F1 score between 0.0 and 1.0
|
||||
|
||||
Examples:
|
||||
>>> locomo_f1_score("Paris", "Paris")
|
||||
1.0
|
||||
>>> locomo_f1_score("The cat", "cat")
|
||||
1.0
|
||||
>>> locomo_f1_score("dog", "cat")
|
||||
0.0
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Normalize and tokenize
|
||||
pred_tokens = normalize_text(pred_str).split()
|
||||
truth_tokens = normalize_text(truth_str).split()
|
||||
|
||||
# Handle empty cases
|
||||
if not pred_tokens or not truth_tokens:
|
||||
return 0.0
|
||||
|
||||
# Convert to sets for comparison
|
||||
pred_set = set(pred_tokens)
|
||||
truth_set = set(truth_tokens)
|
||||
|
||||
# Calculate true positives (intersection)
|
||||
true_positives = len(pred_set & truth_set)
|
||||
|
||||
# Calculate precision and recall
|
||||
precision = true_positives / len(pred_set) if pred_set else 0.0
|
||||
recall = true_positives / len(truth_set) if truth_set else 0.0
|
||||
|
||||
# Calculate F1 score
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
f1 = 2 * precision * recall / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for multi-answer questions.
|
||||
|
||||
Handles comma-separated answers by:
|
||||
1. Splitting both prediction and ground truth by commas
|
||||
2. For each ground truth answer, finding the best matching prediction
|
||||
3. Averaging the F1 scores across all ground truth answers
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer (may contain multiple comma-separated answers)
|
||||
ground_truth: Correct answer (may contain multiple comma-separated answers)
|
||||
|
||||
Returns:
|
||||
Average F1 score across all ground truth answers (0.0 to 1.0)
|
||||
|
||||
Examples:
|
||||
>>> locomo_multi_f1("Paris, London", "Paris, London")
|
||||
1.0
|
||||
>>> locomo_multi_f1("Paris", "Paris, London")
|
||||
0.5
|
||||
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
|
||||
0.5
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Split by commas and strip whitespace
|
||||
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
|
||||
|
||||
# Handle empty cases
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
|
||||
# For each ground truth, find the best matching prediction
|
||||
f1_scores = []
|
||||
for gt in ground_truths:
|
||||
# Calculate F1 with each prediction and take the maximum
|
||||
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
|
||||
f1_scores.append(best_f1)
|
||||
|
||||
# Return average F1 across all ground truths
|
||||
return sum(f1_scores) / len(f1_scores)
|
||||
|
||||
|
||||
def get_category_name(item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract and normalize category name from QA item.
|
||||
|
||||
Handles both numeric categories (1-4) and string categories with various formats.
|
||||
Supports multiple field names: "cat", "category", "type".
|
||||
|
||||
Category mapping:
|
||||
- 1 or "multi-hop" -> "Multi-Hop"
|
||||
- 2 or "temporal" -> "Temporal"
|
||||
- 3 or "open domain" -> "Open Domain"
|
||||
- 4 or "single-hop" -> "Single-Hop"
|
||||
|
||||
Args:
|
||||
item: QA item dictionary containing category information
|
||||
|
||||
Returns:
|
||||
Standardized category name or "unknown" if not found
|
||||
|
||||
Examples:
|
||||
>>> get_category_name({"category": 1})
|
||||
'Multi-Hop'
|
||||
>>> get_category_name({"cat": "temporal"})
|
||||
'Temporal'
|
||||
>>> get_category_name({"type": "Single-Hop"})
|
||||
'Single-Hop'
|
||||
"""
|
||||
# Numeric category mapping
|
||||
CATEGORY_MAP = {
|
||||
1: "Multi-Hop",
|
||||
2: "Temporal",
|
||||
3: "Open Domain",
|
||||
4: "Single-Hop",
|
||||
}
|
||||
|
||||
# String category aliases (case-insensitive)
|
||||
TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
# Try "cat" field first (string category)
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return TYPE_ALIASES.get(lower, name)
|
||||
|
||||
# Try "category" field (can be int or string)
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP.get(cat_num, "unknown")
|
||||
elif isinstance(cat_num, str) and cat_num.strip():
|
||||
lower = cat_num.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_num.strip())
|
||||
|
||||
# Try "type" field as fallback
|
||||
cat_type = item.get("type")
|
||||
if isinstance(cat_type, str) and cat_type.strip():
|
||||
lower = cat_type.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_type.strip())
|
||||
|
||||
return "unknown"
|
||||
@@ -1,810 +0,0 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
||||
def _loc_normalize(text: str) -> str:
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
# 回退到本地实现
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
|
||||
r_counts = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
|
||||
clipped = 0
|
||||
p_counts = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
|
||||
return bp * precision
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p = set(_loc_normalize(pred_str).split())
|
||||
r = set(_loc_normalize(ref_str).split())
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
|
||||
try:
|
||||
# 添加 evaluation 目录路径
|
||||
evaluation_dir = os.path.join(project_root, "evaluation")
|
||||
if evaluation_dir not in sys.path:
|
||||
sys.path.insert(0, evaluation_dir)
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
|
||||
# 回退到本地实现 LoCoMo 特定函数
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
t = str(text) if text is not None else ""
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
p_tokens = _loc_normalize(prediction).split()
|
||||
g_tokens = _loc_normalize(ground_truth).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
|
||||
"""根据问题复杂度和进度动态调整检索参数"""
|
||||
|
||||
# 分析问题复杂度
|
||||
word_count = len(question.split())
|
||||
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
|
||||
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
|
||||
|
||||
# 根据进度调整 - 后期问题可能需要更精确的检索
|
||||
progress_factor = question_index / total_questions
|
||||
|
||||
base_limit = 12
|
||||
if has_temporal and has_multi_hop:
|
||||
base_limit = 20
|
||||
elif word_count > 8:
|
||||
base_limit = 16
|
||||
|
||||
# 随着测试进行,逐渐收紧检索范围
|
||||
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
|
||||
|
||||
# 动态调整最大字符数
|
||||
max_chars = 8000 + 4000 * (1 - progress_factor)
|
||||
|
||||
return {
|
||||
"limit": adjusted_limit,
|
||||
"max_chars": int(max_chars)
|
||||
}
|
||||
|
||||
|
||||
class EnhancedEvaluationMonitor:
|
||||
def __init__(self, reset_interval=5, performance_threshold=0.6):
|
||||
self.question_count = 0
|
||||
self.reset_interval = reset_interval
|
||||
self.performance_threshold = performance_threshold
|
||||
self.consecutive_low_scores = 0
|
||||
self.performance_history = []
|
||||
self.recent_f1_scores = []
|
||||
|
||||
def should_reset_connections(self, current_f1=None):
|
||||
"""基于计数和性能双重判断"""
|
||||
# 定期重置
|
||||
if self.question_count % self.reset_interval == 0:
|
||||
return True
|
||||
|
||||
# 性能驱动的重置
|
||||
if current_f1 is not None and current_f1 < self.performance_threshold:
|
||||
self.consecutive_low_scores += 1
|
||||
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
|
||||
print("🚨 连续低分,触发紧急重置")
|
||||
self.consecutive_low_scores = 0
|
||||
return True
|
||||
else:
|
||||
self.consecutive_low_scores = 0
|
||||
|
||||
return False
|
||||
|
||||
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
|
||||
"""记录性能指标,检测衰减"""
|
||||
self.performance_history.append({
|
||||
'index': question_index,
|
||||
'metrics': metrics,
|
||||
'context_length': context_length,
|
||||
'retrieved_docs': retrieved_docs,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
# 记录最近的F1分数
|
||||
self.recent_f1_scores.append(metrics['f1'])
|
||||
if len(self.recent_f1_scores) > 5:
|
||||
self.recent_f1_scores.pop(0)
|
||||
|
||||
def get_recent_performance(self):
|
||||
"""获取近期平均性能"""
|
||||
if not self.recent_f1_scores:
|
||||
return 0.5
|
||||
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
|
||||
|
||||
def get_performance_trend(self):
|
||||
"""分析性能趋势"""
|
||||
if len(self.performance_history) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
|
||||
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
|
||||
|
||||
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_avg = sum(recent_metrics) / len(recent_metrics)
|
||||
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
|
||||
|
||||
if recent_avg < earlier_avg * 0.8:
|
||||
return "degrading"
|
||||
elif recent_avg > earlier_avg * 1.1:
|
||||
return "improving"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
|
||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
|
||||
"""基于问题复杂度和近期性能动态调整检索参数"""
|
||||
|
||||
# 基础参数
|
||||
base_params = get_dynamic_search_params(question, question_index, total_questions)
|
||||
|
||||
# 性能自适应调整
|
||||
if recent_performance < 0.5: # 近期表现差
|
||||
# 增加检索范围,尝试获取更多上下文
|
||||
base_params["limit"] = min(base_params["limit"] + 5, 25)
|
||||
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
|
||||
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
elif recent_performance > 0.8: # 近期表现好
|
||||
# 收紧检索,提高精度
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 8)
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
|
||||
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
# 中间阶段特殊处理
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5)
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用更精确的检索策略")
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
|
||||
|
||||
return base_params
|
||||
|
||||
|
||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
|
||||
"""考虑问题序列位置的智能选择"""
|
||||
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 在序列中间阶段使用更严格的筛选
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
|
||||
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用严格上下文筛选")
|
||||
|
||||
# 提取问题关键词
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
# 只保留高度相关的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts:
|
||||
context_lower = context.lower()
|
||||
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
|
||||
|
||||
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
|
||||
if any(char.isdigit() for char in context):
|
||||
relevance_score += 2
|
||||
|
||||
# 提高阈值:只有得分>=3的上下文才保留
|
||||
if relevance_score >= 3:
|
||||
filtered_contexts.append(context)
|
||||
else:
|
||||
print(f" - 过滤低分上下文: 得分={relevance_score}")
|
||||
|
||||
contexts = filtered_contexts
|
||||
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
|
||||
|
||||
# 使用原有的智能选择逻辑
|
||||
return smart_context_selection(contexts, question, max_chars)
|
||||
|
||||
|
||||
async def run_enhanced_evaluation():
|
||||
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
current_file = os.path.abspath(__file__)
|
||||
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
|
||||
memory_dir = os.path.dirname(evaluation_dir) # memory目录
|
||||
data_path = os.path.join(memory_dir, "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
qa_items = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
|
||||
items = qa_items[:20] # 测试多少个问题
|
||||
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 初始化连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化结果字典
|
||||
results = {
|
||||
"questions": [],
|
||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
||||
"category_metrics": {},
|
||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
||||
"performance_trend": "stable",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"enhanced_strategy": True
|
||||
}
|
||||
|
||||
total_f1 = 0.0
|
||||
total_bleu1 = 0.0
|
||||
total_jaccard = 0.0
|
||||
total_loc_f1 = 0.0
|
||||
total_context_length = 0
|
||||
total_retrieved_docs = 0
|
||||
category_stats = {}
|
||||
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
monitor.question_count += 1
|
||||
|
||||
# 获取近期性能用于重置判断
|
||||
recent_performance = monitor.get_recent_performance()
|
||||
|
||||
# 增强的重置判断
|
||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
||||
if should_reset and i > 0:
|
||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
||||
await connector.close()
|
||||
connector = Neo4jConnector() # 创建新连接
|
||||
print("✅ 连接重置完成")
|
||||
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
|
||||
print(f"✅ 真实答案: {ref_str}")
|
||||
|
||||
# 分类别统计
|
||||
category = "Unknown"
|
||||
if item.get("category") == 1:
|
||||
category = "Multi-Hop"
|
||||
elif item.get("category") == 2:
|
||||
category = "Temporal"
|
||||
elif item.get("category") == 3:
|
||||
category = "Open Domain"
|
||||
elif item.get("category") == 4:
|
||||
category = "Single-Hop"
|
||||
|
||||
# 增强的检索参数
|
||||
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
|
||||
search_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
|
||||
|
||||
# 使用项目标准的混合检索方法
|
||||
t0 = time.time()
|
||||
contexts_all = []
|
||||
|
||||
try:
|
||||
# 使用统一的搜索服务
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
print("🔀 使用混合搜索服务...")
|
||||
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type="hybrid",
|
||||
group_id="locomo_sk",
|
||||
limit=20,
|
||||
include=["statements", "chunks", "entities", "summaries"],
|
||||
alpha=0.6, # BM25权重
|
||||
embedding_id=SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
# 处理搜索结果 - 新的搜索服务返回统一的结构
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
except Exception as e:
|
||||
print(f"❌ 检索失败: {e}")
|
||||
contexts_all = []
|
||||
|
||||
t1 = time.time()
|
||||
search_time = (t1 - t0) * 1000
|
||||
|
||||
# 增强的上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
# 使用增强的上下文选择
|
||||
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览(不只是第一条)
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
# 🔍 调试:检查答案是否在上下文中
|
||||
if ref_str and ref_str.strip():
|
||||
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
|
||||
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# LLM 回答
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
except Exception as e:
|
||||
print(f"❌ LLM 生成失败: {e}")
|
||||
pred = "Unknown"
|
||||
t3 = time.time()
|
||||
llm_time = (t3 - t2) * 1000
|
||||
|
||||
# 计算指标 - 使用导入的指标函数
|
||||
f1_val = f1_score(pred, ref_str)
|
||||
bleu1_val = bleu1(pred, ref_str)
|
||||
jaccard_val = jaccard(pred, ref_str)
|
||||
loc_f1_val = loc_f1_score(pred, ref_str)
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
|
||||
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
|
||||
|
||||
# 更新统计
|
||||
total_f1 += f1_val
|
||||
total_bleu1 += bleu1_val
|
||||
total_jaccard += jaccard_val
|
||||
total_loc_f1 += loc_f1_val
|
||||
total_context_length += len(context_text)
|
||||
total_retrieved_docs += len(contexts_all)
|
||||
|
||||
if category not in category_stats:
|
||||
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
|
||||
|
||||
category_stats[category]["count"] += 1
|
||||
category_stats[category]["f1_sum"] += f1_val
|
||||
category_stats[category]["b1_sum"] += bleu1_val
|
||||
category_stats[category]["j_sum"] += jaccard_val
|
||||
category_stats[category]["loc_f1_sum"] += loc_f1_val
|
||||
|
||||
# 记录性能指标
|
||||
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
|
||||
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
|
||||
|
||||
# 保存结果
|
||||
question_result = {
|
||||
"question": q,
|
||||
"ground_truth": ref_str,
|
||||
"prediction": pred,
|
||||
"category": category,
|
||||
"metrics": metrics,
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": max_chars,
|
||||
"recent_performance": recent_performance
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_time,
|
||||
"llm_ms": llm_time
|
||||
}
|
||||
}
|
||||
|
||||
results["questions"].append(question_result)
|
||||
|
||||
print("="*60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 评估过程中发生错误: {e}")
|
||||
# 即使出错,也返回已有的结果
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
# 计算总体指标
|
||||
n = len(items)
|
||||
if n > 0:
|
||||
results["overall_metrics"] = {
|
||||
"f1": total_f1 / n,
|
||||
"b1": total_bleu1 / n,
|
||||
"j": total_jaccard / n,
|
||||
"loc_f1": total_loc_f1 / n
|
||||
}
|
||||
|
||||
for category, stats in category_stats.items():
|
||||
count = stats["count"]
|
||||
results["category_metrics"][category] = {
|
||||
"count": count,
|
||||
"f1": stats["f1_sum"] / count,
|
||||
"bleu1": stats["b1_sum"] / count,
|
||||
"jaccard": stats["j_sum"] / count,
|
||||
"loc_f1": stats["loc_f1_sum"] / count
|
||||
}
|
||||
|
||||
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
|
||||
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
|
||||
|
||||
# 分析性能趋势
|
||||
results["performance_trend"] = monitor.get_performance_trend()
|
||||
results["reset_interval"] = monitor.reset_interval
|
||||
results["total_questions_processed"] = monitor.question_count
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
|
||||
print("📋 增强特性:")
|
||||
print(" - 双重重置策略:定期重置 + 性能驱动重置")
|
||||
print(" - 动态检索参数:基于近期性能自适应调整")
|
||||
print(" - 中间阶段严格筛选:提高上下文质量要求")
|
||||
print(" - 连续性能监控:实时检测性能衰减")
|
||||
|
||||
result = asyncio.run(run_enhanced_evaluation())
|
||||
|
||||
print("\n📊 最终评估结果:")
|
||||
print("总体指标:")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.4f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
|
||||
|
||||
print("\n分类别指标:")
|
||||
for category, metrics in result['category_metrics'].items():
|
||||
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
|
||||
|
||||
print("\n检索统计:")
|
||||
stats = result['retrieval_stats']
|
||||
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
|
||||
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
|
||||
|
||||
print(f"\n性能趋势: {result['performance_trend']}")
|
||||
print(f"重置间隔: 每{result['reset_interval']}个问题")
|
||||
print(f"处理问题总数: {result['total_questions_processed']}")
|
||||
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
|
||||
|
||||
|
||||
# 保存结果到指定目录
|
||||
# 使用代码文件所在目录的绝对路径
|
||||
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_file_dir, "results")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n详细结果已保存到: {output_file}")
|
||||
@@ -1,626 +0,0 @@
|
||||
"""
|
||||
LoCoMo Utilities Module
|
||||
|
||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
||||
- Data loading from JSON files
|
||||
- Conversation extraction for ingestion
|
||||
- Temporal reference resolution
|
||||
- Context selection and formatting
|
||||
- Retrieval wrapper functions
|
||||
- Ingestion wrapper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
def load_locomo_data(
|
||||
data_path: str,
|
||||
sample_size: int,
|
||||
conversation_index: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load LoCoMo dataset from JSON file.
|
||||
|
||||
The LoCoMo dataset structure is a list of conversation objects, where each
|
||||
object contains a "qa" list of question-answer pairs.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
sample_size: Number of QA pairs to load (limits total QA items returned)
|
||||
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
|
||||
|
||||
Returns:
|
||||
List of QA item dictionaries, each containing:
|
||||
- question: str
|
||||
- answer: str
|
||||
- category: int (1-4)
|
||||
- evidence: List[str]
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If data_path does not exist
|
||||
json.JSONDecodeError: If file is not valid JSON
|
||||
IndexError: If conversation_index is out of range
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# LoCoMo data structure: list of objects, each with a "qa" list
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
|
||||
if isinstance(raw, list):
|
||||
# Only load QA pairs from the specified conversation
|
||||
if conversation_index < len(raw):
|
||||
entry = raw[conversation_index]
|
||||
if isinstance(entry, dict) and "qa" in entry:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has {len(raw)} conversations."
|
||||
)
|
||||
else:
|
||||
# Fallback: single object with qa list
|
||||
if conversation_index == 0:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has only 1 conversation."
|
||||
)
|
||||
|
||||
# Return only the requested sample size
|
||||
return qa_items[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
"""
|
||||
Extract conversation texts from LoCoMo data for ingestion.
|
||||
|
||||
This function extracts the raw conversation dialogues from the LoCoMo dataset
|
||||
so they can be ingested into the memory system. Each conversation is formatted
|
||||
as a multi-line string with "role: message" format.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
max_dialogues: Maximum number of dialogues to extract (default: 1)
|
||||
|
||||
Returns:
|
||||
List of conversation strings formatted for ingestion.
|
||||
Each string contains multiple lines in format "role: message"
|
||||
|
||||
Example output:
|
||||
[
|
||||
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
|
||||
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
|
||||
]
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# Ensure we have a list of entries
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
contents: List[str] = []
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
|
||||
if not isinstance(conv, dict):
|
||||
continue
|
||||
|
||||
lines: List[str] = []
|
||||
|
||||
# Collect all session_* messages
|
||||
for key, val in sorted(conv.items()):
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
for msg in val:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
role = msg.get("speaker") or "User"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
if lines:
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
||||
"""
|
||||
Resolve relative temporal references to absolute dates.
|
||||
|
||||
This function converts relative time expressions (like "today", "yesterday",
|
||||
"3 days ago") into absolute ISO date strings based on an anchor date.
|
||||
|
||||
Supported patterns:
|
||||
- today, yesterday, tomorrow
|
||||
- X days ago, in X days
|
||||
- last week, next week
|
||||
|
||||
Args:
|
||||
text: Text containing temporal references
|
||||
anchor_date: Reference date for resolution (datetime object)
|
||||
|
||||
Returns:
|
||||
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
|
||||
|
||||
Example:
|
||||
>>> anchor = datetime(2023, 5, 8)
|
||||
>>> resolve_temporal_references("I saw him yesterday", anchor)
|
||||
"I saw him 2023-05-07"
|
||||
"""
|
||||
# Ensure input is a string
|
||||
t = str(text) if text is not None else ""
|
||||
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(
|
||||
r"\btoday\b",
|
||||
anchor_date.date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\byesterday\b",
|
||||
(anchor_date - timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\btomorrow\b",
|
||||
(anchor_date + timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# X days ago
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date - timedelta(days=n)).date().isoformat()
|
||||
|
||||
# in X days
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(
|
||||
r"\b(\d+)\s+days?\s+ago\b",
|
||||
_ago_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bin\s+(\d+)\s+days?\b",
|
||||
_in_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# last week / next week (approximate as 7 days)
|
||||
t = re.sub(
|
||||
r"\blast\s+week\b",
|
||||
(anchor_date - timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bnext\s+week\b",
|
||||
(anchor_date + timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def select_and_format_information(
|
||||
retrieved_info: List[str],
|
||||
question: str,
|
||||
max_chars: int = 8000
|
||||
) -> str:
|
||||
"""
|
||||
Intelligently select and format most relevant retrieved information for LLM prompt.
|
||||
|
||||
This function scores each piece of retrieved information based on keyword matching
|
||||
with the question, then selects the highest-scoring pieces up to the character limit.
|
||||
|
||||
Scoring criteria:
|
||||
- Keyword matches (higher weight for multiple occurrences)
|
||||
- Context length (moderate length preferred)
|
||||
- Position (earlier contexts get bonus points)
|
||||
|
||||
Args:
|
||||
retrieved_info: List of retrieved information strings (chunks, statements, entities)
|
||||
question: Question being answered
|
||||
max_chars: Maximum total characters to include in final prompt
|
||||
|
||||
Returns:
|
||||
Formatted string combining the most relevant information for LLM prompt.
|
||||
Contexts are separated by double newlines.
|
||||
|
||||
Example:
|
||||
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
|
||||
>>> question = "Where did Alice go?"
|
||||
>>> select_and_format_information(contexts, question, max_chars=100)
|
||||
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
|
||||
"""
|
||||
if not retrieved_info:
|
||||
return ""
|
||||
|
||||
# Extract question keywords (filter out stop words and short words)
|
||||
question_lower = question.lower()
|
||||
stop_words = {
|
||||
'what', 'when', 'where', 'who', 'why', 'how',
|
||||
'did', 'do', 'does', 'is', 'are', 'was', 'were',
|
||||
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
|
||||
}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {
|
||||
word for word in question_words
|
||||
if word not in stop_words and len(word) > 2
|
||||
}
|
||||
|
||||
# Score each context
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(retrieved_info):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# Keyword matching score
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# Multiple occurrences increase score
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# Length score (prefer moderate length)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000:
|
||||
score += 5
|
||||
elif context_len >= 2000:
|
||||
score += 2
|
||||
|
||||
# Position bonus (earlier contexts often more relevant)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# Sort by score (descending)
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Select contexts up to character limit
|
||||
selected = []
|
||||
total_chars = 0
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
else:
|
||||
# Try to include high-scoring context by truncating
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# Find lines with keywords
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
selected.append(truncated + "\n[Content truncated...]")
|
||||
total_chars += len(truncated)
|
||||
break
|
||||
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
async def retrieve_relevant_information(
|
||||
question: str,
|
||||
group_id: str,
|
||||
search_type: str,
|
||||
search_limit: int,
|
||||
connector: Any,
|
||||
embedder: Any
|
||||
) -> List[str]:
|
||||
"""
|
||||
Retrieve relevant information from memory graph for a question.
|
||||
|
||||
This function searches the Neo4j memory graph (populated during ingestion) and
|
||||
returns relevant chunks, statements, and entity information that might help
|
||||
answer the question.
|
||||
|
||||
The function supports three search types:
|
||||
- "keyword": Full-text search using Cypher queries
|
||||
- "embedding": Vector similarity search using embeddings
|
||||
- "hybrid": Combination of keyword and embedding search with reranking
|
||||
|
||||
Args:
|
||||
question: Question to search for
|
||||
group_id: Database group ID (identifies which conversation memory to search)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max memory pieces to retrieve
|
||||
connector: Neo4j connector instance
|
||||
embedder: Embedder client instance
|
||||
|
||||
Returns:
|
||||
List of text strings (chunks, statements, entity summaries) from memory graph.
|
||||
Each string represents a piece of retrieved information.
|
||||
|
||||
Raises:
|
||||
Exception: If search fails (caught and returns empty list)
|
||||
"""
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_embedding
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
contexts_all: List[str] = []
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# Embedding-based search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context from chunks
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add summaries
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities (limit to 3 to avoid noise)
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# Keyword-based search
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit
|
||||
)
|
||||
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
|
||||
# Build context from dialogues
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add entity names
|
||||
if entities:
|
||||
entity_names = [
|
||||
str(e.get("name", "")).strip()
|
||||
for e in entities[:5]
|
||||
if e.get("name")
|
||||
]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# Hybrid search with fallback to embedding
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# Handle flat structure (new API format)
|
||||
if search_results and isinstance(search_results, dict):
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Check if we got results
|
||||
if not (chunks or statements or entities or summaries):
|
||||
# Try nested structure (backward compatibility)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to embedding search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context (same for both hybrid and fallback)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
except Exception as e:
|
||||
# Return empty list on error
|
||||
contexts_all = []
|
||||
|
||||
return contexts_all
|
||||
|
||||
|
||||
async def ingest_conversations_if_needed(
|
||||
conversations: List[str],
|
||||
group_id: str,
|
||||
reset: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Wrapper for conversation ingestion using external extraction pipeline.
|
||||
|
||||
This function populates the Neo4j database with processed conversation data
|
||||
(chunks, statements, entities) so that the retrieval system has memory to search.
|
||||
|
||||
The ingestion process:
|
||||
1. Parses conversation text into dialogue messages
|
||||
2. Chunks the dialogues into semantic units
|
||||
3. Extracts statements and entities using LLM
|
||||
4. Generates embeddings for all content
|
||||
5. Stores everything in Neo4j graph database
|
||||
|
||||
Args:
|
||||
conversations: List of raw conversation texts from LoCoMo dataset
|
||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
||||
group_id: Target group ID for database storage
|
||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
||||
Note:
|
||||
The external function uses "contexts" to mean "conversation texts".
|
||||
This runs the full extraction pipeline: chunking → entity extraction →
|
||||
statement extraction → embedding → Neo4j storage.
|
||||
"""
|
||||
try:
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=conversations,
|
||||
group_id=group_id,
|
||||
save_chunk_output=True
|
||||
)
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to ingest conversations: {e}")
|
||||
return False
|
||||
@@ -1,878 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
import re
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
def _loc_normalize(text: str) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text) # 去掉逗号
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week)
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
t = str(text) if text is not None else ""
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
# X days ago / in X days
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
# last week / next week(以7天近似)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
# 单答案 F1:按词集合计算(近似原始实现,去除词干依赖)
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
g_tokens = _loc_normalize(truth_str).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
# 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
|
||||
CATEGORY_MAP_NUM_TO_NAME = {
|
||||
4: "Single-Hop",
|
||||
1: "Multi-Hop",
|
||||
3: "Open Domain",
|
||||
2: "Temporal",
|
||||
}
|
||||
|
||||
_TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
def get_category_label(item: Dict[str, Any]) -> str:
|
||||
# 1) 直接用字符串 cat
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return _TYPE_ALIASES.get(lower, name)
|
||||
# 2) 数字 category 转名称
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
|
||||
# 3) 备用 type 字段
|
||||
t = item.get("type")
|
||||
if isinstance(t, str) and t.strip():
|
||||
lower = t.strip().lower()
|
||||
return _TYPE_ALIASES.get(lower, t.strip())
|
||||
return "unknown"
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_search_params_by_category(category: str):
|
||||
"""根据问题类别调整检索参数"""
|
||||
params_map = {
|
||||
"Multi-Hop": {"limit": 20, "max_chars": 15000},
|
||||
"Temporal": {"limit": 16, "max_chars": 10000},
|
||||
"Open Domain": {"limit": 24, "max_chars": 18000},
|
||||
"Single-Hop": {"limit": 12, "max_chars": 8000},
|
||||
}
|
||||
return params_map.get(category, {"limit": 16, "max_chars": 12000})
|
||||
|
||||
|
||||
async def run_locomo_eval(
|
||||
sample_size: int = 1,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000, # 保持默认值不变
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 32,
|
||||
search_type: str = "hybrid", # 保持默认值不变
|
||||
output_path: str | None = None,
|
||||
skip_ingest_if_exists: bool = True,
|
||||
llm_timeout: float = 10.0,
|
||||
llm_max_retries: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
items: List[Dict[str, Any]] = qa_items[:sample_size]
|
||||
|
||||
# === 保持原来的数据摄入逻辑 ===
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
# 只摄入前1条对话(保持原样)
|
||||
max_dialogues_to_ingest = 1
|
||||
contents: List[str] = []
|
||||
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
sample_id = entry.get("sample_id", f"unknown_{i}")
|
||||
|
||||
print(f"🔍 处理对话 {i+1}: {sample_id}")
|
||||
|
||||
lines: List[str] = []
|
||||
if isinstance(conv, dict):
|
||||
# 收集所有 session_* 的消息
|
||||
session_count = 0
|
||||
for key, val in conv.items():
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
session_count += 1
|
||||
for msg in val:
|
||||
role = msg.get("speaker") or "用户"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
|
||||
|
||||
if not lines:
|
||||
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
|
||||
continue
|
||||
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
|
||||
|
||||
# 选择要评测的QA对(从所有对话中选取)
|
||||
indexed_items: List[tuple[int, Dict[str, Any]]] = []
|
||||
if isinstance(raw, list):
|
||||
for e_idx, entry in enumerate(raw):
|
||||
for qa in entry.get("qa", []):
|
||||
indexed_items.append((e_idx, qa))
|
||||
else:
|
||||
for qa in raw.get("qa", []):
|
||||
indexed_items.append((0, qa))
|
||||
|
||||
# 这里使用sample_size来限制评测的QA数量
|
||||
selected = indexed_items[:sample_size]
|
||||
items: List[Dict[str, Any]] = [qa for _, qa in selected]
|
||||
|
||||
print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")
|
||||
# === 修改结束 ===
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 关键修复:强制重新摄入纯净的对话数据
|
||||
print("🔄 强制重新摄入纯净的对话数据...")
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# connector initialized above
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 上下文诊断收集
|
||||
per_query_context_counts: List[int] = []
|
||||
per_query_context_avg_tokens: List[float] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_tokens_total: List[int] = []
|
||||
# 详细样本调试信息
|
||||
samples: List[Dict[str, Any]] = []
|
||||
# 通用指标
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
# 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1)
|
||||
loc_f1s: List[float] = []
|
||||
# Per-category aggregation
|
||||
cat_counts: Dict[str, int] = {}
|
||||
cat_f1s: Dict[str, List[float]] = {}
|
||||
cat_b1s: Dict[str, List[float]] = {}
|
||||
cat_jss: Dict[str, List[float]] = {}
|
||||
cat_loc_f1s: Dict[str, List[float]] = {}
|
||||
try:
|
||||
for item in items:
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
# 确保答案是字符串
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
cat = get_category_label(item)
|
||||
|
||||
print(f"\n=== 处理问题: {q} ===")
|
||||
|
||||
# 根据类别调整检索参数
|
||||
search_params = get_search_params_by_category(cat)
|
||||
adjusted_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
|
||||
|
||||
# 改进的检索逻辑:使用三路检索(statements, dialogues, entities)
|
||||
t0 = time.time()
|
||||
contexts_all: List[str] = []
|
||||
search_results = None # 保存完整的检索结果
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# 直接调用嵌入检索,包含三路数据
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# 直接调用关键词检索
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit
|
||||
)
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
|
||||
|
||||
# 构建上下文
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
# 实体处理(关键词检索的实体可能没有分数)
|
||||
if entities:
|
||||
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# 🎯 关键修复:混合检索使用更严格的回退机制
|
||||
print("🔀 使用混合检索(带回退机制)...")
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# 🎯 关键修复:正确处理混合检索的扁平结构
|
||||
# 新的API返回扁平结构,直接从顶层获取结果
|
||||
if search_results and isinstance(search_results, dict):
|
||||
# 新API返回扁平结构:直接从顶层获取
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# 检查是否有有效结果
|
||||
if chunks or statements or entities or summaries:
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
|
||||
else:
|
||||
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
|
||||
# 🎯 统一处理:构建上下文(所有检索类型共用)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
# 关键修复:过滤掉包含当前问题答案的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts_all:
|
||||
content = str(context)
|
||||
# 排除包含当前问题标准答案的上下文
|
||||
if ref_str and ref_str.strip() and ref_str.strip() in content:
|
||||
print("🚫 过滤掉包含标准答案的上下文")
|
||||
continue
|
||||
filtered_contexts.append(context)
|
||||
|
||||
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
|
||||
contexts_all = filtered_contexts
|
||||
|
||||
# 输出完整的检索结果信息
|
||||
print("🔍 检索结果详情:")
|
||||
if search_results:
|
||||
output_data = {
|
||||
"statements": [
|
||||
{
|
||||
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
|
||||
"score": s.get("score", 0.0)
|
||||
}
|
||||
for s in (statements[:2] if 'statements' in locals() else [])
|
||||
],
|
||||
"dialogues": [
|
||||
{
|
||||
"uuid": d.get("uuid", ""),
|
||||
"group_id": d.get("group_id", ""),
|
||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
||||
"score": d.get("score", 0.0)
|
||||
}
|
||||
for d in (dialogs[:2] if 'dialogs' in locals() else [])
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"name": e.get("name", ""),
|
||||
"entity_type": e.get("entity_type", ""),
|
||||
"score": e.get("score", 0.0)
|
||||
}
|
||||
for e in (entities[:2] if 'entities' in locals() else [])
|
||||
]
|
||||
}
|
||||
print(json.dumps(output_data, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print(" 无检索结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {search_type}检索失败: {e}")
|
||||
contexts_all = []
|
||||
search_results = None
|
||||
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 使用智能上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# 记录上下文诊断信息
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
|
||||
|
||||
# LLM 提示词
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
# 使用异步调用
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
|
||||
# 计算指标(确保使用字符串)
|
||||
f1_val = common_f1(str(pred), ref_str)
|
||||
b1_val = bleu1(str(pred), ref_str)
|
||||
j_val = jaccard(str(pred), ref_str)
|
||||
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
# Accumulate by category
|
||||
cat_counts[cat] = cat_counts.get(cat, 0) + 1
|
||||
cat_f1s.setdefault(cat, []).append(f1_val)
|
||||
cat_b1s.setdefault(cat, []).append(b1_val)
|
||||
cat_jss.setdefault(cat, []).append(j_val)
|
||||
|
||||
# LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1
|
||||
if item.get("category") in [2, 3, 4]:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
elif item.get("category") in [1]:
|
||||
loc_val = loc_multi_f1(str(pred), ref_str)
|
||||
else:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
loc_f1s.append(loc_val)
|
||||
cat_loc_f1s.setdefault(cat, []).append(loc_val)
|
||||
|
||||
# 保存完整的检索结果信息
|
||||
samples.append({
|
||||
"question": q,
|
||||
"answer": ref_str,
|
||||
"category": cat,
|
||||
"prediction": pred,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val,
|
||||
"loc_f1": loc_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": adjusted_limit,
|
||||
"max_chars": max_chars
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": (t1 - t0) * 1000,
|
||||
"llm_ms": (t3 - t2) * 1000
|
||||
}
|
||||
})
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {ref_str}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
|
||||
|
||||
# Compute per-category averages and dispersion (std, iqr)
|
||||
def _percentile(sorted_vals: List[float], p: float) -> float:
|
||||
if not sorted_vals:
|
||||
return 0.0
|
||||
if len(sorted_vals) == 1:
|
||||
return sorted_vals[0]
|
||||
k = (len(sorted_vals) - 1) * p
|
||||
f = int(k)
|
||||
c = f + 1 if f + 1 < len(sorted_vals) else f
|
||||
if f == c:
|
||||
return sorted_vals[f]
|
||||
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
|
||||
|
||||
by_category: Dict[str, Dict[str, float | int]] = {}
|
||||
for c in cat_counts:
|
||||
f_list = cat_f1s.get(c, [])
|
||||
b_list = cat_b1s.get(c, [])
|
||||
j_list = cat_jss.get(c, [])
|
||||
lf_list = cat_loc_f1s.get(c, [])
|
||||
j_sorted = sorted(j_list)
|
||||
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
|
||||
j_q75 = _percentile(j_sorted, 0.75)
|
||||
j_q25 = _percentile(j_sorted, 0.25)
|
||||
by_category[c] = {
|
||||
"count": cat_counts[c],
|
||||
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
|
||||
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
|
||||
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
|
||||
"j_std": j_std,
|
||||
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
|
||||
# 参考 LoCoMo 评测的类别专用 F1
|
||||
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
|
||||
}
|
||||
|
||||
# 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿
|
||||
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
|
||||
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": sum(f1s) / max(len(f1s), 1),
|
||||
"b1": sum(b1s) / max(len(b1s), 1),
|
||||
"j": sum(jss) / max(len(jss), 1),
|
||||
# LoCoMo 类别专用 F1 的总体
|
||||
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
|
||||
},
|
||||
"by_category": by_category,
|
||||
"category_counts": cat_counts,
|
||||
"cum_accuracy_by_category": cum_accuracy_by_category,
|
||||
"context": {
|
||||
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"skip_ingest_if_exists": skip_ingest_if_exists,
|
||||
"llm_timeout": llm_timeout,
|
||||
"llm_max_retries": llm_max_retries,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens
|
||||
},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
if output_path:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ 结果已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"❌ 保存结果失败: {e}")
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
||||
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
|
||||
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
|
||||
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
|
||||
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
|
||||
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
|
||||
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
|
||||
args = parser.parse_args()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
result = asyncio.run(run_locomo_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
output_path=args.output_path,
|
||||
skip_ingest_if_exists=args.skip_ingest_if_exists,
|
||||
llm_timeout=args.llm_timeout,
|
||||
llm_max_retries=args.llm_max_retries
|
||||
))
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("📊 最终评测结果:")
|
||||
print(f" 样本数量: {result['items']}")
|
||||
print(f" F1: {result['metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
|
||||
print(f" Jaccard: {result['metrics']['j']:.3f}")
|
||||
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
|
||||
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
|
||||
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
|
||||
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
|
||||
|
||||
if result['by_category']:
|
||||
print("\n📈 按类别细分:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" 样本数: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,324 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
|
||||
if not contexts:
|
||||
return ""
|
||||
import re
|
||||
# 提取问题关键词(移除停用词)
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
# 评分
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择直到达到字符限制,必要时截断包含关键词的段落
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
|
||||
"""Compose a text context from `dialog` list in msc_self_instruct item."""
|
||||
parts: List[str] = []
|
||||
for turn in dialog_obj.get("dialog", []):
|
||||
speaker = turn.get("speaker", "")
|
||||
text = turn.get("text", "")
|
||||
if text:
|
||||
parts.append(f"{speaker}: {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Combine dialogues from embedding and keyword searches (embedding first)."""
|
||||
if results is None:
|
||||
return []
|
||||
emb = []
|
||||
kw = []
|
||||
if isinstance(results.get("embedding_search"), dict):
|
||||
emb = results.get("embedding_search", {}).get("dialogues", []) or []
|
||||
elif isinstance(results.get("dialogues"), list):
|
||||
emb = results.get("dialogues", []) or []
|
||||
if isinstance(results.get("keyword_search"), dict):
|
||||
kw = results.get("keyword_search", {}).get("dialogues", []) or []
|
||||
seen = set()
|
||||
merged: List[Dict[str, Any]] = []
|
||||
for d in emb:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
for d in kw:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
|
||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
contexts_used: List[str] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
try:
|
||||
for item in items:
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
# 检索:对齐 locomo 的三路检索(dialogues/statements/entities)
|
||||
t0 = time.time()
|
||||
try:
|
||||
results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
|
||||
contexts_all: List[str] = []
|
||||
if results:
|
||||
if search_type == "hybrid":
|
||||
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
|
||||
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
|
||||
emb_dialogs = emb.get("dialogues", [])
|
||||
emb_statements = emb.get("statements", [])
|
||||
emb_entities = emb.get("entities", [])
|
||||
kw_dialogs = kw.get("dialogues", [])
|
||||
kw_statements = kw.get("statements", [])
|
||||
kw_entities = kw.get("entities", [])
|
||||
all_dialogs = emb_dialogs + kw_dialogs
|
||||
all_statements = emb_statements + kw_statements
|
||||
all_entities = emb_entities + kw_entities
|
||||
|
||||
# 简单去重与限制
|
||||
seen_texts = set()
|
||||
for d in all_dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
for s in all_statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
# 实体摘要(最多3个)
|
||||
names = []
|
||||
merged_entities = all_entities[:]
|
||||
for e in merged_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
if name and name not in names:
|
||||
names.append(name)
|
||||
if len(names) >= 3:
|
||||
break
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
else:
|
||||
dialogs = results.get("dialogues", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
for d in dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
|
||||
# 智能选择并截断到预算
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text[:200])
|
||||
|
||||
# Call LLM (使用异步调用)
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
t2 = time.time()
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
# Aggregate metrics
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"accuracy": acc,
|
||||
# Placeholders for extensibility
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"avg_context_tokens": ctx_avg_tokens,
|
||||
}
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
|
||||
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(
|
||||
run_memsciqa_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
)
|
||||
)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,576 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
tp = len(set(ps) & set(rs))
|
||||
if tp == 0:
|
||||
return 0.0
|
||||
precision = tp / len(ps)
|
||||
recall = tp / len(rs)
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
overlap = len([w for w in ps if w in rs])
|
||||
return overlap / max(len(ps), 1)
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
ps = set(pred.lower().split())
|
||||
rs = set(ref.lower().split())
|
||||
union = len(ps | rs)
|
||||
if union == 0:
|
||||
return 0.0
|
||||
return len(ps & rs) / union
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
|
||||
|
||||
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
|
||||
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3)。"""
|
||||
ql = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
|
||||
}
|
||||
words = re.findall(r"\b[\w-]+\b", ql)
|
||||
kws = [w for w in words if w not in stop_words and len(w) >= 3]
|
||||
# 去重保序
|
||||
seen = set()
|
||||
uniq = []
|
||||
for w in kws:
|
||||
if w not in seen:
|
||||
uniq.append(w)
|
||||
seen.add(w)
|
||||
if len(uniq) >= max_keywords:
|
||||
break
|
||||
return uniq
|
||||
|
||||
|
||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
||||
|
||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
||||
"""
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
tl = (ctx or "").lower()
|
||||
match_count = sum(1 for k in keywords if k in tl)
|
||||
length = len(ctx)
|
||||
score = match_count * 200 + min(length, 100000) / 100.0
|
||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
||||
return results[:max(top_n, 0)]
|
||||
|
||||
|
||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
||||
|
||||
|
||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"未找到数据集: {data_path}")
|
||||
items: List[Dict[str, Any]] = []
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
items.append(json.loads(line))
|
||||
except Exception:
|
||||
# 跳过坏行但不中断
|
||||
continue
|
||||
return items
|
||||
|
||||
|
||||
async def run_memsciqa_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 64,
|
||||
search_type: str = "embedding",
|
||||
data_path: str | None = None,
|
||||
start_index: int = 0,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
|
||||
|
||||
- 支持从指定索引开始与评估全部样本(sample_size<=0)
|
||||
- 支持在摄入前重置组(清空图)与跳过摄入
|
||||
- 支持 keyword / embedding / hybrid 三种检索
|
||||
"""
|
||||
|
||||
# 默认使用指定的 memsci 组 ID
|
||||
group_id = group_id or "group_memsci"
|
||||
|
||||
# 数据路径解析(项目根与当前工作目录兜底)
|
||||
if not data_path:
|
||||
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
if os.path.exists(proj_path):
|
||||
data_path = proj_path
|
||||
elif os.path.exists(cwd_path):
|
||||
data_path = cwd_path
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
|
||||
# 加载数据
|
||||
all_items = load_dataset_memsciqa(data_path)
|
||||
if sample_size is None or sample_size <= 0:
|
||||
items = all_items[start_index:]
|
||||
else:
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 评估循环
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 存储完整上下文文本用于统计
|
||||
contexts_used: List[str] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_counts: List[int] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
total_items = len(items)
|
||||
for idx, item in enumerate(items):
|
||||
if verbose:
|
||||
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
|
||||
# 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py)
|
||||
t0 = time.time()
|
||||
results = None
|
||||
try:
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
# 使用嵌入检索(与 qwen_search_eval 对齐)
|
||||
results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
elif search_type == "keyword":
|
||||
# 关键词检索(直接调用 graph_search)
|
||||
results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
search_ms = (t1 - t0) * 1000
|
||||
latencies_search.append(search_ms)
|
||||
|
||||
# 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py)
|
||||
contexts_all: List[str] = []
|
||||
retrieved_counts: Dict[str, int] = {}
|
||||
if results:
|
||||
chunks = results.get("chunks", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
summaries = results.get("summaries", [])
|
||||
retrieved_counts = {
|
||||
"chunks": len(chunks),
|
||||
"statements": len(statements),
|
||||
"entities": len(entities),
|
||||
"summaries": len(summaries),
|
||||
}
|
||||
# 优先使用 chunks
|
||||
for c in chunks:
|
||||
text = str(c.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 statements
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 summaries
|
||||
for sm in summaries:
|
||||
text = str(sm.get("summary", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
if verbose:
|
||||
if retrieved_counts:
|
||||
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
q_keywords = extract_question_keywords(question, max_keywords=8)
|
||||
if q_keywords:
|
||||
print(f"🔍 问题关键词: {set(q_keywords)}")
|
||||
if contexts_all:
|
||||
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
|
||||
if analysis:
|
||||
print("📊 上下文相关性分析:")
|
||||
for a in analysis:
|
||||
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
|
||||
# 打印检索到的上下文预览,便于定位为何为 Unknown
|
||||
print("🔎 上下文预览(最多前10条,每条截断展示):")
|
||||
for i, ctx in enumerate(contexts_all[:10]):
|
||||
preview = str(ctx).replace("\n", " ")
|
||||
if len(preview) > 300:
|
||||
preview = preview[:300] + "..."
|
||||
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
|
||||
# 标注参考答案是否出现在任一上下文中
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower:
|
||||
hits = []
|
||||
for i, ctx in enumerate(contexts_all):
|
||||
if ref_lower in str(ctx).lower():
|
||||
hits.append(i+1)
|
||||
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
|
||||
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text)
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
|
||||
if verbose:
|
||||
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
|
||||
# 展示拼接后的上下文片段,便于核查是否包含答案
|
||||
concat_preview = context_text.replace("\n", " ")
|
||||
if len(concat_preview) > 600:
|
||||
concat_preview = concat_preview[:600] + "..."
|
||||
print(f"🧵 拼接上下文预览: {concat_preview}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
|
||||
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
|
||||
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
|
||||
"3) Keep your answer brief and to the point;\n"
|
||||
"4) Do not add explanations or additional text beyond the answer."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 更健壮的响应解析,处理不同的LLM响应格式
|
||||
if hasattr(resp, 'content'):
|
||||
pred = resp.content.strip()
|
||||
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
|
||||
pred = resp["choices"][0]["message"]["content"].strip()
|
||||
elif isinstance(resp, dict) and "content" in resp:
|
||||
pred = resp["content"].strip()
|
||||
elif isinstance(resp, str):
|
||||
pred = resp.strip()
|
||||
else:
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
|
||||
|
||||
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
|
||||
if pred.lower() in ["unknown", ""]:
|
||||
# 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
|
||||
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
|
||||
except Exception as e:
|
||||
# 更详细的错误处理
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM调用异常: {e}")
|
||||
t3 = time.time()
|
||||
llm_ms = (t3 - t2) * 1000
|
||||
latencies_llm.append(llm_ms)
|
||||
|
||||
exact = exact_match(pred, reference)
|
||||
correct_flags.append(exact)
|
||||
f1_val = f1_score(str(pred), str(reference))
|
||||
b1_val = bleu1(str(pred), str(reference))
|
||||
j_val = jaccard(str(pred), str(reference))
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
if verbose:
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {reference}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
|
||||
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
|
||||
|
||||
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
|
||||
samples.append({
|
||||
"question": str(question),
|
||||
"answer": str(reference),
|
||||
"prediction": str(pred),
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": context_char_budget
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_ms,
|
||||
"llm_ms": llm_ms
|
||||
}
|
||||
})
|
||||
|
||||
# 计算总体指标与聚合
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"context": {
|
||||
"avg_tokens": ctx_avg_tokens,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": 0.0
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens,
|
||||
"search_type": search_type,
|
||||
"start_index": start_index,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
try:
|
||||
await connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
|
||||
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
||||
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
|
||||
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
|
||||
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)")
|
||||
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
|
||||
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
|
||||
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
|
||||
args = parser.parse_args()
|
||||
|
||||
sample_size = 0 if args.all else args.sample_size
|
||||
|
||||
verbose_flag = False if args.quiet else args.verbose
|
||||
result = asyncio.run(
|
||||
run_memsciqa_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
data_path=args.data_path,
|
||||
start_index=args.start_index,
|
||||
verbose=verbose_flag,
|
||||
)
|
||||
)
|
||||
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果保存
|
||||
out_path = args.output
|
||||
if not out_path:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, "results")
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n💾 结果已保存: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 结果保存失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,150 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
||||
|
||||
|
||||
async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
group_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
llm_temperature: float | None = None,
|
||||
llm_max_tokens: int | None = None,
|
||||
search_type: str | None = None,
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(group_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
if start_index is not None:
|
||||
kwargs["start_index"] = start_index
|
||||
if max_contexts_per_item is not None:
|
||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
||||
return await run_longmemeval_test(**kwargs)
|
||||
raise ValueError(f"未知数据集: {dataset}")
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
|
||||
# 仅透传到 longmemeval;其他数据集忽略
|
||||
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
|
||||
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
|
||||
parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(run(
|
||||
args.dataset,
|
||||
args.sample_size,
|
||||
args.reset_group,
|
||||
args.group_id,
|
||||
args.judge_model,
|
||||
args.search_limit,
|
||||
args.context_char_budget,
|
||||
args.llm_temperature,
|
||||
args.llm_max_tokens,
|
||||
args.search_type,
|
||||
args.start_index,
|
||||
args.max_contexts_per_item,
|
||||
))
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果输出逻辑保持不变
|
||||
if args.output:
|
||||
out_path = args.output
|
||||
else:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
|
||||
out_filename = f"{args.dataset}_{args.sample_size}.json"
|
||||
out_path = os.path.join(dataset_results_dir, out_filename)
|
||||
|
||||
out_dir = os.path.dirname(out_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n结果已保存到: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -187,11 +187,11 @@ class ChunkerClient:
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
@@ -201,9 +201,9 @@ class ChunkerClient:
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
@@ -212,13 +212,13 @@ class ChunkerClient:
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
@@ -228,14 +228,14 @@ class ChunkerClient:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
@@ -260,7 +260,7 @@ class ChunkerClient:
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
@@ -268,7 +268,7 @@ class ChunkerClient:
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
|
||||
@@ -58,6 +58,25 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology extraction models (for extraction flow)
|
||||
from app.core.memory.models.ontology_extraction_models import (
|
||||
OntologyTypeInfo,
|
||||
OntologyTypeList,
|
||||
)
|
||||
|
||||
# Ontology general models (loaded from external ontology files)
|
||||
from app.core.memory.models.ontology_general_models import (
|
||||
OntologyFileFormat,
|
||||
GeneralOntologyType,
|
||||
GeneralOntologyTypeRegistry,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -105,6 +124,16 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Ontology type models for extraction flow
|
||||
"OntologyTypeInfo",
|
||||
"OntologyTypeList",
|
||||
# General ontology type models
|
||||
"OntologyFileFormat",
|
||||
"GeneralOntologyType",
|
||||
"GeneralOntologyTypeRegistry",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
||||
"""Parameters for temporal search queries in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
group_id: Group ID to filter search results (default: 'test')
|
||||
end_user_id: Group ID to filter search results (default: 'test')
|
||||
apply_id: Application ID to filter search results
|
||||
user_id: User ID to filter search results
|
||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||
limit: Maximum number of results to return (default: 3)
|
||||
"""
|
||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||
|
||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
||||
id: Unique identifier for the edge
|
||||
source: ID of the source node
|
||||
target: ID of the target node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
target: str = Field(..., description="The ID of the target node.")
|
||||
group_id: str = Field(..., description="The group ID of the edge.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the node
|
||||
name: Name of the node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
group_id: str = Field(..., description="The group ID of the node.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
@@ -421,7 +413,8 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class Statement(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the statement
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
end_user_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
@@ -73,7 +73,7 @@ class Statement(BaseModel):
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
|
||||
context: Full conversation context
|
||||
dialog_embedding: Optional embedding vector for the entire dialog
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
|
||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
|
||||
return []
|
||||
|
||||
def assign_group_id_to_statements(self) -> None:
|
||||
"""Assign this dialog's group_id to all statements in all chunks.
|
||||
"""Assign this dialog's end_user_id to all statements in all chunks.
|
||||
|
||||
This method updates statements that don't have a group_id set.
|
||||
This method updates statements that don't have a end_user_id set.
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.group_id is None:
|
||||
statement.group_id = self.group_id
|
||||
if statement.end_user_id is None:
|
||||
statement.end_user_id = self.end_user_id
|
||||
|
||||
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型数据结构模块
|
||||
|
||||
本模块定义用于在萃取流程中传递本体类型信息的轻量级数据类。
|
||||
|
||||
Classes:
|
||||
OntologyTypeInfo: 单个本体类型信息
|
||||
OntologyTypeList: 本体类型列表
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeInfo:
|
||||
"""本体类型信息,用于萃取流程中传递。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称
|
||||
class_description: 类型描述
|
||||
"""
|
||||
class_name: str
|
||||
class_description: str
|
||||
|
||||
def to_prompt_format(self) -> str:
|
||||
"""转换为提示词格式。
|
||||
|
||||
Returns:
|
||||
格式化的字符串,如 "- TypeName: Description"
|
||||
"""
|
||||
return f"- {self.class_name}: {self.class_description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeList:
|
||||
"""本体类型列表。
|
||||
|
||||
Attributes:
|
||||
types: 本体类型信息列表
|
||||
"""
|
||||
types: List[OntologyTypeInfo]
|
||||
|
||||
@classmethod
|
||||
def from_db_models(cls, ontology_classes: list) -> "OntologyTypeList":
|
||||
"""从数据库模型转换创建 OntologyTypeList。
|
||||
|
||||
Args:
|
||||
ontology_classes: OntologyClass 数据库模型列表,
|
||||
每个对象应包含 class_name 和 class_description 属性
|
||||
|
||||
Returns:
|
||||
包含转换后类型信息的 OntologyTypeList 实例
|
||||
"""
|
||||
types = [
|
||||
OntologyTypeInfo(
|
||||
class_name=oc.class_name,
|
||||
class_description=oc.class_description or ""
|
||||
)
|
||||
for oc in ontology_classes
|
||||
]
|
||||
return cls(types=types)
|
||||
|
||||
def to_prompt_section(self) -> str:
|
||||
"""转换为提示词中的类型列表部分。
|
||||
|
||||
Returns:
|
||||
格式化的类型列表字符串,每行一个类型;
|
||||
如果列表为空则返回空字符串
|
||||
"""
|
||||
if not self.types:
|
||||
return ""
|
||||
lines = [t.to_prompt_format() for t in self.types]
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_type_names(self) -> List[str]:
|
||||
"""获取所有类型名称列表。
|
||||
|
||||
Returns:
|
||||
类型名称字符串列表
|
||||
"""
|
||||
return [t.class_name for t in self.types]
|
||||
|
||||
def get_type_hierarchy_hints(self) -> List[str]:
|
||||
"""获取类型层次结构提示列表。
|
||||
|
||||
尝试从通用本体注册表中获取每个类型的继承链信息。
|
||||
|
||||
Returns:
|
||||
层次提示字符串列表,格式为 "类型名 → 父类1 → 父类2"
|
||||
"""
|
||||
hints = []
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger
|
||||
|
||||
merger = OntologyTypeMerger()
|
||||
for type_info in self.types:
|
||||
hint = merger.get_type_hierarchy_hint(type_info.class_name)
|
||||
if hint:
|
||||
hints.append(hint)
|
||||
except Exception:
|
||||
# 如果无法获取层次信息,返回空列表
|
||||
pass
|
||||
|
||||
return hints
|
||||
223
api/app/core/memory/models/ontology_general_models.py
Normal file
223
api/app/core/memory/models/ontology_general_models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""通用本体类型数据模型模块
|
||||
|
||||
本模块定义用于通用本体类型管理的数据结构,包括:
|
||||
- OntologyFileFormat: 本体文件格式枚举
|
||||
- GeneralOntologyType: 通用本体类型数据类
|
||||
- GeneralOntologyTypeRegistry: 通用本体类型注册表
|
||||
|
||||
Classes:
|
||||
OntologyFileFormat: 本体文件格式枚举,支持 TTL、OWL/XML、RDF/XML、N-Triples、JSON-LD
|
||||
GeneralOntologyType: 通用本体类型,包含类名、URI、标签、描述、父类等信息
|
||||
GeneralOntologyTypeRegistry: 类型注册表,管理类型集合和层次结构
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyFileFormat(Enum):
|
||||
"""本体文件格式枚举
|
||||
|
||||
支持的格式:
|
||||
- TURTLE: Turtle 格式 (.ttl 文件)
|
||||
- RDF_XML: RDF/XML 格式 (.owl, .rdf 文件)
|
||||
- N_TRIPLES: N-Triples 格式 (.nt 文件)
|
||||
- JSON_LD: JSON-LD 格式 (.jsonld, .json 文件)
|
||||
"""
|
||||
TURTLE = "turtle" # .ttl 文件
|
||||
RDF_XML = "xml" # .owl, .rdf (RDF/XML 格式)
|
||||
N_TRIPLES = "nt" # .nt 文件
|
||||
JSON_LD = "json-ld" # .jsonld 文件
|
||||
|
||||
@classmethod
|
||||
def from_extension(cls, file_path: str) -> "OntologyFileFormat":
|
||||
"""根据文件扩展名推断格式
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
推断出的文件格式,默认返回 RDF_XML
|
||||
"""
|
||||
ext = file_path.lower().split('.')[-1]
|
||||
format_map = {
|
||||
'ttl': cls.TURTLE,
|
||||
'owl': cls.RDF_XML,
|
||||
'rdf': cls.RDF_XML,
|
||||
'nt': cls.N_TRIPLES,
|
||||
'jsonld': cls.JSON_LD,
|
||||
'json': cls.JSON_LD,
|
||||
}
|
||||
return format_map.get(ext, cls.RDF_XML)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyType:
|
||||
"""通用本体类型
|
||||
|
||||
表示从本体文件中解析出的类型定义,包含类型的基本信息和层次关系。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称,如 "Person"
|
||||
class_uri: 完整 URI,如 "http://dbpedia.org/ontology/Person"
|
||||
labels: 多语言标签字典,键为语言代码(如 "en", "zh"),值为标签文本
|
||||
description: 类型描述
|
||||
parent_class: 父类名称,用于构建类型层次
|
||||
source_file: 来源文件路径
|
||||
"""
|
||||
class_name: str # 类型名称,如 "Person"
|
||||
class_uri: str # 完整 URI
|
||||
labels: Dict[str, str] = field(default_factory=dict) # 多语言标签
|
||||
description: Optional[str] = None # 类型描述
|
||||
parent_class: Optional[str] = None # 父类名称
|
||||
source_file: Optional[str] = None # 来源文件
|
||||
|
||||
def get_label(self, lang: str = "en") -> str:
|
||||
"""获取指定语言的标签
|
||||
|
||||
优先返回指定语言的标签,如果不存在则尝试返回英文标签,
|
||||
最后返回类型名称作为默认值。
|
||||
|
||||
Args:
|
||||
lang: 语言代码,默认为 "en"
|
||||
|
||||
Returns:
|
||||
指定语言的标签,或默认值
|
||||
"""
|
||||
return self.labels.get(lang, self.labels.get("en", self.class_name))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyTypeRegistry:
|
||||
"""通用本体类型注册表
|
||||
|
||||
管理解析后的本体类型集合,提供类型查询、层次遍历、注册表合并等功能。
|
||||
|
||||
Attributes:
|
||||
types: 类型字典,键为类型名称,值为 GeneralOntologyType 实例
|
||||
hierarchy: 层次结构字典,键为父类名称,值为子类名称集合
|
||||
source_files: 已加载的源文件路径列表
|
||||
"""
|
||||
types: Dict[str, GeneralOntologyType] = field(default_factory=dict)
|
||||
hierarchy: Dict[str, Set[str]] = field(default_factory=dict) # 父类 -> 子类集合
|
||||
source_files: List[str] = field(default_factory=list)
|
||||
|
||||
def get_type(self, name: str) -> Optional[GeneralOntologyType]:
|
||||
"""根据名称获取类型
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
对应的 GeneralOntologyType 实例,如果不存在则返回 None
|
||||
"""
|
||||
return self.types.get(name)
|
||||
|
||||
def get_ancestors(self, name: str) -> List[str]:
|
||||
"""获取类型的所有祖先类型(防循环)
|
||||
|
||||
从当前类型开始,沿着父类链向上遍历,返回所有祖先类型名称。
|
||||
使用 visited 集合防止循环引用导致的无限循环。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
祖先类型名称列表,按从近到远的顺序排列
|
||||
"""
|
||||
ancestors = []
|
||||
current = name
|
||||
visited = set()
|
||||
while current and current not in visited:
|
||||
visited.add(current)
|
||||
type_info = self.types.get(current)
|
||||
if type_info and type_info.parent_class:
|
||||
# 检测循环引用
|
||||
if type_info.parent_class in visited:
|
||||
logger.warning(
|
||||
f"检测到类型层次循环引用: {current} -> {type_info.parent_class},"
|
||||
f"已遍历路径: {' -> '.join([name] + ancestors)}"
|
||||
)
|
||||
break
|
||||
ancestors.append(type_info.parent_class)
|
||||
current = type_info.parent_class
|
||||
else:
|
||||
break
|
||||
return ancestors
|
||||
|
||||
def get_descendants(self, name: str) -> Set[str]:
|
||||
"""获取类型的所有后代类型
|
||||
|
||||
从当前类型开始,沿着子类关系向下遍历,返回所有后代类型名称。
|
||||
使用广度优先搜索,避免重复处理已访问的类型。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
后代类型名称集合
|
||||
"""
|
||||
descendants: Set[str] = set()
|
||||
to_process = [name]
|
||||
while to_process:
|
||||
current = to_process.pop()
|
||||
children = self.hierarchy.get(current, set())
|
||||
new_children = children - descendants
|
||||
descendants.update(new_children)
|
||||
to_process.extend(new_children)
|
||||
return descendants
|
||||
|
||||
def merge(self, other: "GeneralOntologyTypeRegistry") -> None:
|
||||
"""合并另一个注册表(先加载的优先)
|
||||
|
||||
将另一个注册表的类型和层次结构合并到当前注册表。
|
||||
对于同名类型,保留当前注册表中已存在的定义(先加载优先)。
|
||||
层次结构会合并所有子类关系。
|
||||
|
||||
Args:
|
||||
other: 要合并的另一个注册表
|
||||
"""
|
||||
for name, type_info in other.types.items():
|
||||
if name not in self.types:
|
||||
self.types[name] = type_info
|
||||
for parent, children in other.hierarchy.items():
|
||||
if parent not in self.hierarchy:
|
||||
self.hierarchy[parent] = set()
|
||||
self.hierarchy[parent].update(children)
|
||||
self.source_files.extend(other.source_files)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取注册表统计信息
|
||||
|
||||
Returns:
|
||||
包含以下键的字典:
|
||||
- total_types: 总类型数
|
||||
- root_types: 根类型数(无父类的类型)
|
||||
- max_depth: 类型层次的最大深度
|
||||
- source_files: 源文件列表
|
||||
"""
|
||||
return {
|
||||
"total_types": len(self.types),
|
||||
"root_types": len([t for t in self.types.values() if not t.parent_class]),
|
||||
"max_depth": self._calculate_max_depth(),
|
||||
"source_files": self.source_files,
|
||||
}
|
||||
|
||||
def _calculate_max_depth(self) -> int:
|
||||
"""计算类型层次的最大深度
|
||||
|
||||
遍历所有类型,计算每个类型到根的深度,返回最大值。
|
||||
|
||||
Returns:
|
||||
类型层次的最大深度
|
||||
"""
|
||||
max_depth = 0
|
||||
for type_name in self.types:
|
||||
depth = len(self.get_ancestors(type_name))
|
||||
max_depth = max(max_depth, depth)
|
||||
return max_depth
|
||||
138
api/app/core/memory/models/ontology_scenario_models.py
Normal file
138
api/app/core/memory/models/ontology_scenario_models.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Models for ontology classes and extraction responses.
|
||||
|
||||
This module contains Pydantic models for representing extracted ontology classes
|
||||
from scenario descriptions, following OWL ontology engineering standards.
|
||||
|
||||
Classes:
|
||||
OntologyClass: Represents an extracted ontology class
|
||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class OntologyClass(BaseModel):
|
||||
"""Represents an extracted ontology class from scenario description.
|
||||
|
||||
An ontology class represents an abstract category or concept in a domain,
|
||||
following OWL ontology engineering standards and naming conventions.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the ontology class
|
||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
||||
description: Textual description of the class
|
||||
examples: List of concrete instance examples of this class
|
||||
parent_class: Optional name of the parent class in the hierarchy
|
||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: uuid4().hex,
|
||||
description="Unique identifier for the ontology class"
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the class in PascalCase format"
|
||||
)
|
||||
name_chinese: Optional[str] = Field(
|
||||
None,
|
||||
description="Chinese translation of the class name"
|
||||
)
|
||||
description: str = Field(
|
||||
...,
|
||||
description="Description of the class"
|
||||
)
|
||||
examples: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of concrete instance examples"
|
||||
)
|
||||
parent_class: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the parent class in the hierarchy"
|
||||
)
|
||||
entity_type: str = Field(
|
||||
...,
|
||||
description="Type/category of the entity"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain this class belongs to"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_pascal_case(cls, v: str) -> str:
|
||||
"""Validate that the class name follows PascalCase convention.
|
||||
|
||||
PascalCase rules:
|
||||
- Must start with an uppercase letter (for English) or any character (for Chinese/Unicode)
|
||||
- Cannot contain spaces
|
||||
- Should not contain special characters except underscores
|
||||
|
||||
Args:
|
||||
v: The class name to validate
|
||||
|
||||
Returns:
|
||||
The validated class name
|
||||
|
||||
Raises:
|
||||
ValueError: If the name doesn't follow PascalCase convention
|
||||
"""
|
||||
if not v:
|
||||
raise ValueError("Class name cannot be empty")
|
||||
|
||||
# For Chinese/Unicode characters, skip the uppercase check
|
||||
# Only check uppercase for ASCII letters
|
||||
first_char = v[0]
|
||||
if first_char.isascii() and first_char.isalpha() and not first_char.isupper():
|
||||
raise ValueError(
|
||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
||||
)
|
||||
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
||||
)
|
||||
|
||||
# Check for invalid characters (allow alphanumeric, underscore, and Unicode characters)
|
||||
if not all(c.isalnum() or c == '_' or ord(c) > 127 for c in v):
|
||||
raise ValueError(
|
||||
f"Class name '{v}' contains invalid characters. "
|
||||
"Only alphanumeric characters, underscores, and Unicode characters are allowed"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class OntologyExtractionResponse(BaseModel):
|
||||
"""Response model for ontology extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting ontology classes from scenario descriptions.
|
||||
|
||||
Attributes:
|
||||
classes: List of extracted ontology classes
|
||||
domain: Domain/field the scenario belongs to
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
classes: List[OntologyClass] = Field(
|
||||
default_factory=list,
|
||||
description="List of extracted ontology classes"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain/field the scenario belongs to"
|
||||
)
|
||||
39
api/app/core/memory/ontology_services/__init__.py
Normal file
39
api/app/core/memory/ontology_services/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型服务模块
|
||||
|
||||
本模块提供本体类型相关的服务,包括:
|
||||
- OntologyTypeMerger: 本体类型合并服务
|
||||
- get_general_ontology_registry: 获取通用本体类型注册表(单例,懒加载)
|
||||
- get_ontology_type_merger: 获取类型合并服务实例
|
||||
- reload_ontology_registry: 重新加载本体注册表(实验模式)
|
||||
- clear_ontology_cache: 清除本体缓存
|
||||
- is_general_ontology_enabled: 检查通用本体类型功能是否启用
|
||||
- load_ontology_types_for_scene: 从数据库加载场景的本体类型
|
||||
- create_empty_ontology_type_list: 创建空的本体类型列表
|
||||
- load_ontology_types_with_fallback: 加载本体类型(带通用类型回退)
|
||||
"""
|
||||
|
||||
from .ontology_type_merger import OntologyTypeMerger, DEFAULT_CORE_GENERAL_TYPES
|
||||
from .ontology_type_loader import (
|
||||
get_general_ontology_registry,
|
||||
get_ontology_type_merger,
|
||||
reload_ontology_registry,
|
||||
clear_ontology_cache,
|
||||
is_general_ontology_enabled,
|
||||
load_ontology_types_for_scene,
|
||||
create_empty_ontology_type_list,
|
||||
load_ontology_types_with_fallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OntologyTypeMerger",
|
||||
"DEFAULT_CORE_GENERAL_TYPES",
|
||||
"get_general_ontology_registry",
|
||||
"get_ontology_type_merger",
|
||||
"reload_ontology_registry",
|
||||
"clear_ontology_cache",
|
||||
"is_general_ontology_enabled",
|
||||
"load_ontology_types_for_scene",
|
||||
"create_empty_ontology_type_list",
|
||||
"load_ontology_types_with_fallback",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user