Compare commits
974 Commits
release/v0
...
hotfix/v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fa4be10e51 | ||
|
|
dcb7b496d3 | ||
|
|
9535545947 | ||
|
|
59f5c7a8bb | ||
|
|
1305a08c86 | ||
|
|
fe29141437 | ||
|
|
17d3c81c02 | ||
|
|
baf02e4faa | ||
|
|
4d6038c3cc | ||
|
|
d4450658a8 | ||
|
|
3ceb2efeaf | ||
|
|
e134b96333 | ||
|
|
4df41966fe | ||
|
|
2d6cde157e | ||
|
|
abc27c8372 | ||
|
|
dbe387f666 | ||
|
|
5e70d436a8 | ||
|
|
b7198f1abd | ||
|
|
5c87a2beeb | ||
|
|
3419bb137a | ||
|
|
a00684c67d | ||
|
|
6e7c641fd4 | ||
|
|
0c677701c0 | ||
|
|
4974f9aa98 | ||
|
|
c90b58bbcd | ||
|
|
d6a243f1be | ||
|
|
418114ef72 | ||
|
|
ceed61167f | ||
|
|
83774d7443 | ||
|
|
052c7c19b3 | ||
|
|
d42db0ca33 | ||
|
|
e15af5a2ba | ||
|
|
8b44b2cd61 | ||
|
|
9d91453200 | ||
|
|
ea8db7cd90 | ||
|
|
d60f16df1b | ||
|
|
8dd24533bf | ||
|
|
91b7f2a980 | ||
|
|
f7e89af9d2 | ||
|
|
afbd8c9b4f | ||
|
|
09b3b01d37 | ||
|
|
e3dcbed5f9 | ||
|
|
c7b51e7ad8 | ||
|
|
c0cd2373c0 | ||
|
|
6e757ae9e2 | ||
|
|
64a73c41d6 | ||
|
|
dae7431075 | ||
|
|
643bbbcf5c | ||
|
|
6702e86536 | ||
|
|
ab2bdfa088 | ||
|
|
8285250096 | ||
|
|
e59a215078 | ||
|
|
c89eccf8fe | ||
|
|
5703fc0cb4 | ||
|
|
3aed5c447a | ||
|
|
13352178ad | ||
|
|
8f216db353 | ||
|
|
9f6026492d | ||
|
|
b699b746a5 | ||
|
|
6095170169 | ||
|
|
173697e86a | ||
|
|
5c11da6a2e | ||
|
|
96214c433f | ||
|
|
167c915631 | ||
|
|
f485398768 | ||
|
|
289b1989e5 | ||
|
|
8224848ce1 | ||
|
|
c43d258455 | ||
|
|
c3e5c8b8bb | ||
|
|
930cadcaa8 | ||
|
|
57b6b34567 | ||
|
|
f878846364 | ||
|
|
7dce63dc0b | ||
|
|
03bc8ee7f5 | ||
|
|
4aefb01b0b | ||
|
|
4e9b5736b1 | ||
|
|
46fa99a8b8 | ||
|
|
17ea92357d | ||
|
|
bd70a8b812 | ||
|
|
ad5dc3c138 | ||
|
|
e37b1b01ca | ||
|
|
e659ca9fa2 | ||
|
|
758be0087f | ||
|
|
200c13b59f | ||
|
|
32f6886000 | ||
|
|
7fbf3e8873 | ||
|
|
3026702000 | ||
|
|
8677db114b | ||
|
|
2597a1f532 | ||
|
|
4298cd7d06 | ||
|
|
8197f9db35 | ||
|
|
3da6331515 | ||
|
|
539999131c | ||
|
|
d0ca5c8b27 | ||
|
|
ee6b8ffa62 | ||
|
|
14838dc064 | ||
|
|
e017870f44 | ||
|
|
9730c5ce0f | ||
|
|
f30260939a | ||
|
|
8ba0a74473 | ||
|
|
4f69224cfd | ||
|
|
6f7fee18c9 | ||
|
|
cc58c7333c | ||
|
|
c936277507 | ||
|
|
701df40270 | ||
|
|
b724dbe53a | ||
|
|
ac7c891ded | ||
|
|
3ed6f49bb0 | ||
|
|
a416a6b2bd | ||
|
|
35be03803f | ||
|
|
6427018ffb | ||
|
|
06b823ff96 | ||
|
|
0fdb489227 | ||
|
|
f6394a791e | ||
|
|
4bfd4944d0 | ||
|
|
7faf291ec3 | ||
|
|
3d291e3c23 | ||
|
|
b35bedc730 | ||
|
|
4d39cdf464 | ||
|
|
a874cc70a4 | ||
|
|
2319432182 | ||
|
|
7556468c6e | ||
|
|
91d38c0648 | ||
|
|
df3d58d388 | ||
|
|
80856e3c92 | ||
|
|
8c6f395818 | ||
|
|
2f4f7219e3 | ||
|
|
4c5183eddc | ||
|
|
dfc0ee9424 | ||
|
|
8dbb067b83 | ||
|
|
1df3fc416a | ||
|
|
6223b80cc4 | ||
|
|
68489f1b28 | ||
|
|
477853b04e | ||
|
|
863be50aaf | ||
|
|
d72d57f966 | ||
|
|
5b940e5f1a | ||
|
|
9ae1d2f0d9 | ||
|
|
318f1be107 | ||
|
|
4cab6317de | ||
|
|
81bfc9af36 | ||
|
|
189013f0f8 | ||
|
|
6f5bcd18a4 | ||
|
|
c7ef97c7a6 | ||
|
|
4d4a780ab7 | ||
|
|
9d2f3aa8f9 | ||
|
|
f2c9902a07 | ||
|
|
2525f8795c | ||
|
|
b7a03a844f | ||
|
|
c13c3846d1 | ||
|
|
30b5db1e98 | ||
|
|
f92eb9f45a | ||
|
|
a136d44e27 | ||
|
|
65b2f9e6e1 | ||
|
|
5275a274c3 | ||
|
|
4f09c4fbb3 | ||
|
|
7a3220aff5 | ||
|
|
14a32778f7 | ||
|
|
2a12cb04bf | ||
|
|
1e986c641f | ||
|
|
38c6c7f053 | ||
|
|
7c0743eb8f | ||
|
|
e981f066a3 | ||
|
|
db14d40fb3 | ||
|
|
e8d575fd0b | ||
|
|
a7285e35ad | ||
|
|
c4461c4917 | ||
|
|
2df615eca0 | ||
|
|
504e5ba61e | ||
|
|
0bae290e0c | ||
|
|
294ee49d59 | ||
|
|
26c36f70e6 | ||
|
|
c4b83b1f9c | ||
|
|
14413fd413 | ||
|
|
caab58dd2f | ||
|
|
0e899bea05 | ||
|
|
1794f8f209 | ||
|
|
85daf576e9 | ||
|
|
56fd5680cf | ||
|
|
0380c13a3b | ||
|
|
9ddc523f91 | ||
|
|
491ef27b8a | ||
|
|
edd115582f | ||
|
|
45eef12842 | ||
|
|
49364802c2 | ||
|
|
8873078006 | ||
|
|
2b9fd33bc8 | ||
|
|
e86d679ae5 | ||
|
|
def7367e33 | ||
|
|
54cff5861a | ||
|
|
dc2a73155b | ||
|
|
1856c55c04 | ||
|
|
522eb569f1 | ||
|
|
9df41456f6 | ||
|
|
04c54081c8 | ||
|
|
1c49e3c167 | ||
|
|
fb6ce839d2 | ||
|
|
c7275dccac | ||
|
|
d62b484d71 | ||
|
|
8ff1c6bd08 | ||
|
|
3dcf901043 | ||
|
|
d6dfc2cb12 | ||
|
|
8a3032ce4a | ||
|
|
391c60c812 | ||
|
|
b739b032d9 | ||
|
|
3dc863cabf | ||
|
|
611b14dfea | ||
|
|
de6e2f54d2 | ||
|
|
89d188fbf3 | ||
|
|
6bba574ca6 | ||
|
|
9cbffd6408 | ||
|
|
4d2ad5757c | ||
|
|
cd0ca9cae4 | ||
|
|
3369b702e4 | ||
|
|
cbec2c1356 | ||
|
|
5987eee0a8 | ||
|
|
6348304b7d | ||
|
|
59f8010519 | ||
|
|
9308c6efae | ||
|
|
2f78b7cf5e | ||
|
|
f86448f4bf | ||
|
|
48e2e613bb | ||
|
|
1060074740 | ||
|
|
95b7df7e38 | ||
|
|
fd1634eec4 | ||
|
|
efeead41b2 | ||
|
|
a3428c2435 | ||
|
|
31b8a3764e | ||
|
|
2ff81ba101 | ||
|
|
93deb286a3 | ||
|
|
7bd97bf6d3 | ||
|
|
2d1a1b4a1f | ||
|
|
503c890d93 | ||
|
|
1f73501786 | ||
|
|
eef13cb717 | ||
|
|
c70ac1339e | ||
|
|
24c13d408e | ||
|
|
338d7f1065 | ||
|
|
27672cfaa0 | ||
|
|
4dbb2bf2e2 | ||
|
|
37bc4beab4 | ||
|
|
31085ed678 | ||
|
|
dce7206c44 | ||
|
|
c17a2dad2d | ||
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
035cc17264 | ||
|
|
a079358028 | ||
|
|
cf26c9f39c | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
9f947a3395 | ||
|
|
bf5c4628c3 | ||
|
|
911d5e0b34 | ||
|
|
bd31aa5abf | ||
|
|
0775fad5f0 | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
fabc8936ab | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
06de54ebfd | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
7c6e48b04e | ||
|
|
b1b53f6b1d | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
fcc81ac025 | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
69c001bf84 | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
84c23e7c4e | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
ba65b06582 | ||
|
|
f4f04036f3 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
477d404727 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
88598fb9fb | ||
|
|
19d149c129 | ||
|
|
f09de3a11c | ||
|
|
e13acdc8a9 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
3c99fb116c | ||
|
|
e7e136036c | ||
|
|
ca84fc6c9d | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
32d612fbeb | ||
|
|
9ce3a881f3 | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
1635f9dbef | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
75e36173cd | ||
|
|
8097f227ca | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
fd1debe681 | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
39f3a85bb1 | ||
|
|
098a2e54ae | ||
|
|
d575478b53 | ||
|
|
aab54ca1a8 | ||
|
|
d4f2094ee0 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
01a1e8eab1 | ||
|
|
b71bc1f875 | ||
|
|
6a0ee22d81 | ||
|
|
cbc8714414 | ||
|
|
065f8db2f7 | ||
|
|
0ac7f83726 | ||
|
|
d03473da10 | ||
|
|
dac1c01a2c | ||
|
|
f6d929ab7a | ||
|
|
a7a2dabc5a | ||
|
|
83015a3404 | ||
|
|
b88e9c5f5e | ||
|
|
8380a8a811 | ||
|
|
6c69181290 | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
7b8f101824 | ||
|
|
0f36c5c872 | ||
|
|
6a67f028ce | ||
|
|
5d82786c20 | ||
|
|
e368f1c1d6 | ||
|
|
572ce7f9ec | ||
|
|
a4c942a21f | ||
|
|
4859ab3ba7 | ||
|
|
983b5f5087 | ||
|
|
75b87955dd | ||
|
|
110de0afbc | ||
|
|
2c074cd5c1 | ||
|
|
73e51a9b0b | ||
|
|
3a47039919 | ||
|
|
2961ea4e44 | ||
|
|
af2ffc9737 | ||
|
|
d7911244fc | ||
|
|
2a66775e45 | ||
|
|
6029a5a9a8 | ||
|
|
71d9ae15a1 | ||
|
|
f0c3d5f308 | ||
|
|
4706ea59fe | ||
|
|
5774a95f61 | ||
|
|
d660521c5c | ||
|
|
5db2c5092e | ||
|
|
59618457df | ||
|
|
c612dfbc1f | ||
|
|
fc58ac0408 | ||
|
|
8d053c97a7 | ||
|
|
a3e6f67ff7 | ||
|
|
01da2e3eee | ||
|
|
168cce1678 | ||
|
|
7240dfe793 | ||
|
|
b9340ba02d | ||
|
|
4f5ee24bc5 | ||
|
|
6a1b8d3ee3 | ||
|
|
f1207dc8b9 | ||
|
|
86c51559bb | ||
|
|
8b0f806079 | ||
|
|
99e94b3567 | ||
|
|
cfd5c1bc93 | ||
|
|
45d9e45346 | ||
|
|
fcb3845543 | ||
|
|
97eabc0c36 | ||
|
|
5328163973 | ||
|
|
7ff9dfee8c | ||
|
|
5b431400be | ||
|
|
1e1675ec12 | ||
|
|
f941541304 | ||
|
|
3f7083c5b3 | ||
|
|
e81faebf69 | ||
|
|
8a4d58c520 | ||
|
|
2ac29ee89c | ||
|
|
252cdcd6f5 | ||
|
|
16e2c95965 | ||
|
|
10560fb34c | ||
|
|
58aa60ca0e | ||
|
|
d24b186d3e | ||
|
|
b4e81615b1 | ||
|
|
424d2033ea | ||
|
|
fd556f9b00 | ||
|
|
e2f5fa87b1 | ||
|
|
e4a2bd3b9b | ||
|
|
e3ada17a78 | ||
|
|
3e5a7adfe4 | ||
|
|
3237f4cd6e | ||
|
|
beea826377 | ||
|
|
7cdbbefc64 | ||
|
|
18780622b3 | ||
|
|
f405ac4d84 | ||
|
|
9fe47e2fb2 | ||
|
|
e4aaa18f61 | ||
|
|
5c3d9717dd | ||
|
|
ac86bbd60c | ||
|
|
33d12c43b2 | ||
|
|
107c676185 | ||
|
|
0f221b7ee6 | ||
|
|
e1939ef472 | ||
|
|
5438d35f17 | ||
|
|
9c26d1f4c8 | ||
|
|
4c2b31f31f | ||
|
|
4f88a13256 | ||
|
|
21ae448ed7 | ||
|
|
50466124c8 | ||
|
|
ece88a3879 | ||
|
|
cedc4a92cc | ||
|
|
c8065b0c60 | ||
|
|
476632294f | ||
|
|
349d46e043 | ||
|
|
00e0201bf9 | ||
|
|
b9ebe22df1 | ||
|
|
389dd8d402 | ||
|
|
966bd8528d | ||
|
|
8f789d47a2 | ||
|
|
509d1a2e24 | ||
|
|
94a40e49a0 | ||
|
|
8429279eea | ||
|
|
cef14cda9e | ||
|
|
c14f067afb | ||
|
|
6c8dca6379 | ||
|
|
819d205166 | ||
|
|
153e68e055 | ||
|
|
77b9a6a94e | ||
|
|
d68bbab419 | ||
|
|
6d53d9178c | ||
|
|
9e17f65eda | ||
|
|
7373f68172 | ||
|
|
0999bd30d7 | ||
|
|
f01185a7fc | ||
|
|
7cd7303754 | ||
|
|
d19fec2155 | ||
|
|
2612abc9d0 | ||
|
|
06fe3f2f01 | ||
|
|
e2b6c713e7 | ||
|
|
0b3b241436 | ||
|
|
4c18f9e858 | ||
|
|
8fec54c085 | ||
|
|
d8e37a4d2b | ||
|
|
1da2c4fa37 | ||
|
|
d080b44ac3 | ||
|
|
df18868888 | ||
|
|
4438b08560 | ||
|
|
1029f94669 | ||
|
|
0a3acf446d | ||
|
|
c01ad5a19e | ||
|
|
5a7723553c | ||
|
|
975844eccf | ||
|
|
865ad31f2f | ||
|
|
b756f0c86c | ||
|
|
3e5f6176af | ||
|
|
ab5b165dc2 | ||
|
|
f9393c2f63 | ||
|
|
aa6638424c | ||
|
|
834387e254 | ||
|
|
9caa986c80 | ||
|
|
72b84dfc8f | ||
|
|
af10195025 | ||
|
|
22382423ad | ||
|
|
0f80c67cbd | ||
|
|
aa6473c1c7 | ||
|
|
cde61cb6ac | ||
|
|
b1368997c2 | ||
|
|
ec7dc448c1 | ||
|
|
254147265e | ||
|
|
479bba9a4e | ||
|
|
cfb39a6baa | ||
|
|
05c9ed1450 | ||
|
|
f53633a8b8 | ||
|
|
f56bc0f85a | ||
|
|
63882e9391 | ||
|
|
3c4dfb868f | ||
|
|
9600d687fa | ||
|
|
cae9105b8d | ||
|
|
41a0036bf6 | ||
|
|
2c9401ccfb | ||
|
|
08e4ad6a7c | ||
|
|
2b0dedc81c | ||
|
|
314e6e29d5 | ||
|
|
16b87de0df | ||
|
|
8c3af7f4ff | ||
|
|
391cd602a2 | ||
|
|
5f56cc8056 | ||
|
|
827ab27bef | ||
|
|
ccc67df8df | ||
|
|
82538c469f | ||
|
|
076ceee29d | ||
|
|
822b73b015 | ||
|
|
862bff51cb | ||
|
|
247db844a4 | ||
|
|
5495d32822 | ||
|
|
bccbeaabe4 | ||
|
|
a496991400 | ||
|
|
0ea83b4364 | ||
|
|
03676b7adc | ||
|
|
af6fde414f | ||
|
|
d069809001 | ||
|
|
fc240849cf | ||
|
|
61d2a328fe | ||
|
|
fed0ae8e9c | ||
|
|
eaf0de453b | ||
|
|
e833db954a | ||
|
|
0b2651f4ed | ||
|
|
10c677a6fd | ||
|
|
3398c4737a | ||
|
|
a008f5fbef | ||
|
|
6a42e73667 | ||
|
|
7611db19f3 | ||
|
|
d3399dfaf5 | ||
|
|
248f0d95ac | ||
|
|
5c39d841ee | ||
|
|
87be67cb9a | ||
|
|
1a08bea864 | ||
|
|
bc4406cec6 | ||
|
|
4206c849c3 | ||
|
|
3f052b7798 | ||
|
|
f1c5f24f6b | ||
|
|
e981c95225 | ||
|
|
4ce4f53835 | ||
|
|
f16e369540 | ||
|
|
47bf93d65e | ||
|
|
5c2e0af33e | ||
|
|
aaa0410781 | ||
|
|
366b148f3d | ||
|
|
6a265de31c | ||
|
|
c3707f543c | ||
|
|
8de368348b | ||
|
|
d052c31ac5 | ||
|
|
31320afed6 | ||
|
|
7afe507296 | ||
|
|
4188443101 | ||
|
|
a1fc0fd394 | ||
|
|
71fe35533d | ||
|
|
a2ed335e59 | ||
|
|
8422a05d74 | ||
|
|
139ae3bcb4 | ||
|
|
a0a57d5fbb | ||
|
|
80fa88ac37 | ||
|
|
0fda1c752d | ||
|
|
6c2fc75199 | ||
|
|
2cb6aeb022 | ||
|
|
e0174f75b3 | ||
|
|
51d04746a3 | ||
|
|
3b08d6c320 | ||
|
|
495c5802a0 | ||
|
|
621b074b3d | ||
|
|
6df32983b5 | ||
|
|
9c9fe9dde7 | ||
|
|
128c1a6178 | ||
|
|
f90e102854 | ||
|
|
2e1eb9a5a6 | ||
|
|
60a95f6556 | ||
|
|
218637e81d | ||
|
|
404f78af0f | ||
|
|
130f15665c | ||
|
|
6301528301 | ||
|
|
6feea968e0 | ||
|
|
b5199b2eb9 | ||
|
|
78ce2a9a8b | ||
|
|
6ed542b007 | ||
|
|
5322b0c4a3 | ||
|
|
a72d5d2c77 | ||
|
|
16c1cbe24f | ||
|
|
0d8f4c76e7 | ||
|
|
e511b14933 | ||
|
|
b5ba53208e | ||
|
|
b8bfb4d0c5 | ||
|
|
1b666638bc | ||
|
|
2bd364eca3 | ||
|
|
f27fc51801 | ||
|
|
0f85eff76b | ||
|
|
0def474cc2 | ||
|
|
026e4376d4 | ||
|
|
cf571cf02b | ||
|
|
590ec3a446 | ||
|
|
23bfdcefef | ||
|
|
647a978865 | ||
|
|
86f72100f0 | ||
|
|
8b255259ba | ||
|
|
8aad8faae9 | ||
|
|
420f391f3c | ||
|
|
817221347f | ||
|
|
13dce5e265 | ||
|
|
850d9ee70b | ||
|
|
ba36ccb21f | ||
|
|
f712754927 | ||
|
|
efe3865aa4 | ||
|
|
53dbe2f436 | ||
|
|
720498084b | ||
|
|
f5eda38dc9 | ||
|
|
8ada221777 | ||
|
|
4ee198813a | ||
|
|
440e8acd99 | ||
|
|
218671ef06 | ||
|
|
34de0bb9c5 | ||
|
|
8e6cf09056 | ||
|
|
5929072b76 | ||
|
|
37325e9802 | ||
|
|
778bc4bd70 | ||
|
|
f78f59ec42 | ||
|
|
d4c4160215 | ||
|
|
85aea97c21 | ||
|
|
b075cad4de | ||
|
|
f326febc8a | ||
|
|
1738e45090 | ||
|
|
6e758faa37 | ||
|
|
32e79c5df0 | ||
|
|
aa69cd3a0c | ||
|
|
da4a1f536d | ||
|
|
b3af757167 | ||
|
|
82794f051a | ||
|
|
a726a81224 | ||
|
|
9aae6163f0 | ||
|
|
941527e7ee | ||
|
|
a3f05220d3 | ||
|
|
7446241735 | ||
|
|
6033d37537 | ||
|
|
1524d7b5ce | ||
|
|
e00341a4cc | ||
|
|
f5185d2e95 | ||
|
|
c041d24989 | ||
|
|
dc9003f9db | ||
|
|
07e0c70629 | ||
|
|
37f77e0990 | ||
|
|
aef1a57ea8 | ||
|
|
69af479224 | ||
|
|
f38223c97f | ||
|
|
1ac6702eb0 | ||
|
|
2510f60dce | ||
|
|
b9d7fb2598 | ||
|
|
a39ba564fa | ||
|
|
34310bfabe | ||
|
|
78fd189510 | ||
|
|
94836ed9af | ||
|
|
1d662fb63e | ||
|
|
d1933d2aef | ||
|
|
163872be6e | ||
|
|
14fcb66a9c | ||
|
|
c488eb0cd0 | ||
|
|
91d20f7272 | ||
|
|
c3d7963fe0 | ||
|
|
c31a92bf01 | ||
|
|
b5703c1b82 | ||
|
|
df34735a9b | ||
|
|
31bee889d7 | ||
|
|
b3ba0a6ed6 | ||
|
|
ce3b7897d7 | ||
|
|
9115ad6950 | ||
|
|
c6b76438f4 | ||
|
|
68c4c7429c | ||
|
|
8466c8e019 | ||
|
|
d899b27448 | ||
|
|
229eb5cc86 | ||
|
|
66c153f1ad | ||
|
|
bbb2c6c903 | ||
|
|
5edf3f2b8a | ||
|
|
006c6cd159 | ||
|
|
9675982555 | ||
|
|
c6c7a1827c | ||
|
|
3ac8a9431b | ||
|
|
5c42a84c3e | ||
|
|
8fdaebbe6e | ||
|
|
9a98ccff2c | ||
|
|
ee4027c561 | ||
|
|
7f36a06f26 | ||
|
|
0826a34d8b | ||
|
|
1792cb4d93 | ||
|
|
304ccef101 | ||
|
|
bdc22c892d | ||
|
|
a5034e84ba | ||
|
|
6e2de96fed | ||
|
|
2b6d86e591 | ||
|
|
8c6f4cb117 | ||
|
|
16d4b32eb7 | ||
|
|
45a64dbbac | ||
|
|
537668b463 | ||
|
|
07fea23dd0 | ||
|
|
cef14291f0 | ||
|
|
bbde0588af | ||
|
|
aa7d52568b | ||
|
|
f39c77ac70 | ||
|
|
aa733354e8 | ||
|
|
7cec966979 | ||
|
|
74865d2cf2 | ||
|
|
c9a8753473 | ||
|
|
ce8a2cbe34 | ||
|
|
c0fdd0c6d3 | ||
|
|
88bfcfe6cd | ||
|
|
c4dcf1fd65 | ||
|
|
6cebddf893 | ||
|
|
1738ed3664 | ||
|
|
37ddcb91ac | ||
|
|
574ab4506b | ||
|
|
81353538e5 | ||
|
|
5abfcdfbe8 | ||
|
|
9962a61c21 | ||
|
|
5cf2b08777 | ||
|
|
9be1c01b70 | ||
|
|
62b2ecdfc2 | ||
|
|
2ff9000d25 | ||
|
|
5829148ce4 | ||
|
|
8e15a340f6 | ||
|
|
1270b7cdd8 | ||
|
|
7c02fe8148 | ||
|
|
4ac63e1c23 | ||
|
|
4aeb653ed2 | ||
|
|
2d5c2de613 | ||
|
|
96590941cf | ||
|
|
0655ff4a91 | ||
|
|
0ba370052e | ||
|
|
4d59e04aba | ||
|
|
6db6c33564 | ||
|
|
ed0d963aec | ||
|
|
3a36d038ee | ||
|
|
3d068a9c96 | ||
|
|
87df352adc | ||
|
|
8b546b7366 | ||
|
|
77ea0680fb | ||
|
|
4c592bf7e3 | ||
|
|
6718553bf4 | ||
|
|
79dc6f3f69 | ||
|
|
8df72d2822 | ||
|
|
b9578bd08a | ||
|
|
3ce5926689 | ||
|
|
035464c0ac | ||
|
|
f1fcffbfc0 | ||
|
|
b79fe07052 | ||
|
|
e6aa0e0e10 | ||
|
|
54700e6fbe | ||
|
|
3a0671c661 | ||
|
|
1037729fb3 | ||
|
|
5f211620c5 | ||
|
|
cb6a3aae9e | ||
|
|
5e512df3d4 | ||
|
|
9916cf3265 | ||
|
|
f7aed9dd98 | ||
|
|
5253cf3899 | ||
|
|
f7d92be5ea | ||
|
|
97d8168824 | ||
|
|
550bd4da23 | ||
|
|
a7ffc19ba1 | ||
|
|
b9201c918a | ||
|
|
4f0b653a82 | ||
|
|
616709acbb | ||
|
|
67053ab8ae | ||
|
|
33238d34c9 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -25,10 +25,13 @@ examples/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
tmp
|
||||
files
|
||||
powers/
|
||||
|
||||
# Exclude dep files
|
||||
huggingface.co/
|
||||
|
||||
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (Using Redis as broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (使用Redis作为broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||
# Database connection URL - DO NOT hardcode credentials here!
|
||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||
# Example: postgresql://user:password@localhost:5432/dbname
|
||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建连接池
|
||||
pool = ConnectionPool.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
@@ -21,6 +23,51 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||
|
||||
# Thread-local storage for connection pools.
|
||||
# Each thread (and each forked process) gets its own pool to avoid
|
||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||
# and stale connections after fork in --pool=prefork.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||
"""Return a Redis client whose connection pool is bound to the current
|
||||
thread, process **and** event loop.
|
||||
|
||||
The pool is recreated when:
|
||||
- The PID changes (fork, Celery --pool=prefork)
|
||||
- The thread has no pool yet (Celery --pool=threads)
|
||||
- The previously-cached event loop has been closed (Celery tasks call
|
||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||
"""
|
||||
current_pid = os.getpid()
|
||||
cached_loop = getattr(_thread_local, "loop", None)
|
||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||
|
||||
if not hasattr(_thread_local, "pool") \
|
||||
or getattr(_thread_local, "pid", None) != current_pid \
|
||||
or loop_stale:
|
||||
_thread_local.pid = current_pid
|
||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||
try:
|
||||
_thread_local.loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
_thread_local.loop = None
|
||||
_thread_local.pool = ConnectionPool.from_url(
|
||||
_REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
try:
|
||||
@@ -29,7 +76,8 @@ async def get_redis_connection():
|
||||
logger.error(f"Redis连接失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
|
||||
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
"""设置Redis键值
|
||||
|
||||
Args:
|
||||
@@ -40,16 +88,15 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
else:
|
||||
# 设置永久键值
|
||||
await aio_redis.set(key, val)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
|
||||
async def aio_redis_get(key: str):
|
||||
"""获取Redis键值"""
|
||||
try:
|
||||
@@ -58,6 +105,7 @@ async def aio_redis_get(key: str):
|
||||
logger.error(f"Redis get错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_delete(key: str):
|
||||
"""删除Redis键"""
|
||||
try:
|
||||
@@ -66,6 +114,7 @@ async def aio_redis_delete(key: str):
|
||||
logger.error(f"Redis delete错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
"""发布消息到Redis频道"""
|
||||
try:
|
||||
@@ -78,9 +127,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
logger.error(f"Redis发布错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
"""Redis订阅器"""
|
||||
|
||||
|
||||
def __init__(self, channel: str):
|
||||
self.channel = channel
|
||||
self.conn = None
|
||||
@@ -88,25 +138,25 @@ class RedisSubscriber:
|
||||
self.is_closed = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._task = None
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始订阅"""
|
||||
if self.is_closed or self._task:
|
||||
return
|
||||
|
||||
|
||||
self._task = asyncio.create_task(self._receive_messages())
|
||||
logger.info(f"开始订阅: {self.channel}")
|
||||
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""接收消息"""
|
||||
try:
|
||||
self.conn = await get_redis_connection()
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
|
||||
self.pubsub = self.conn.pubsub()
|
||||
await self.pubsub.subscribe(self.channel)
|
||||
|
||||
|
||||
while not self.is_closed:
|
||||
try:
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
|
||||
@@ -127,7 +177,7 @@ class RedisSubscriber:
|
||||
finally:
|
||||
await self._queue.put(None)
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pubsub:
|
||||
@@ -141,7 +191,7 @@ class RedisSubscriber:
|
||||
await self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def get_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取消息"""
|
||||
if self.is_closed:
|
||||
@@ -153,7 +203,7 @@ class RedisSubscriber:
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭订阅器"""
|
||||
if self.is_closed:
|
||||
@@ -163,32 +213,33 @@ class RedisSubscriber:
|
||||
self._task.cancel()
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
"""Redis发布订阅管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.subscribers = {}
|
||||
|
||||
|
||||
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
|
||||
return await aio_redis_publish(channel, message)
|
||||
|
||||
|
||||
def get_subscriber(self, channel: str) -> RedisSubscriber:
|
||||
if channel in self.subscribers:
|
||||
subscriber = self.subscribers[channel]
|
||||
if not subscriber.is_closed:
|
||||
return subscriber
|
||||
|
||||
|
||||
subscriber = RedisSubscriber(channel)
|
||||
self.subscribers[channel] = subscriber
|
||||
return subscriber
|
||||
|
||||
|
||||
def cancel_subscription(self, channel: str) -> bool:
|
||||
if channel in self.subscribers:
|
||||
asyncio.create_task(self.subscribers[channel].close())
|
||||
del self.subscribers[channel]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cancel_all_subscriptions(self) -> int:
|
||||
count = len(self.subscribers)
|
||||
for subscriber in self.subscribers.values():
|
||||
@@ -196,6 +247,6 @@ class RedisPubSubManager:
|
||||
self.subscribers.clear()
|
||||
return count
|
||||
|
||||
|
||||
# 全局实例
|
||||
pubsub_manager = RedisPubSubManager()
|
||||
|
||||
|
||||
5
api/app/cache/__init__.py
vendored
5
api/app/cache/__init__.py
vendored
@@ -3,9 +3,8 @@ Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -3,10 +3,10 @@ Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
from .interest_memory import InterestMemoryCache
|
||||
from .activity_stats_cache import ActivityStatsCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
"ActivityStatsCache",
|
||||
]
|
||||
|
||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Recent Activity Stats Cache
|
||||
|
||||
记忆提取活动统计缓存模块
|
||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class ActivityStatsCache:
|
||||
"""记忆提取活动统计缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:activity_stats"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, workspace_id: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||
|
||||
@classmethod
|
||||
async def set_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
stats: Dict[str, Any],
|
||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
stats: 统计数据,格式:
|
||||
{
|
||||
"chunk_count": int,
|
||||
"statements_count": int,
|
||||
"triplet_entities_count": int,
|
||||
"triplet_relations_count": int,
|
||||
"temporal_count": int,
|
||||
}
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
payload = {
|
||||
"stats": stats,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
统计数据字典,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
return payload
|
||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> bool:
|
||||
"""删除记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
134
api/app/cache/memory/emotion_memory.py
vendored
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
136
api/app/cache/memory/implicit_memory.py
vendored
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -1,25 +1,58 @@
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _mask_url(url: str) -> str:
|
||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||
# 未配置则回退到 Redis 方案
|
||||
# backend: 结果存储(使用 Redis)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||
# integration and accidentally override our canonical URLs.
|
||||
os.environ.pop("BROKER_URL", None)
|
||||
os.environ.pop("RESULT_BACKEND", None)
|
||||
os.environ.pop("CELERY_BROKER", None)
|
||||
os.environ.pop("CELERY_BACKEND", None)
|
||||
|
||||
celery_app = Celery(
|
||||
"redbear_tasks",
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
broker=_broker_url,
|
||||
backend=_backend_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _mask_url(_broker_url),
|
||||
"backend": _mask_url(_backend_url),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
@@ -34,55 +67,64 @@ celery_app.conf.update(
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=True,
|
||||
# # 时区
|
||||
# timezone='Asia/Shanghai',
|
||||
# enable_utc=False,
|
||||
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -90,13 +132,16 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# 这个30秒的设计不合理
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
@@ -115,16 +160,16 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
#如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
1
api/app/config/__init__.py
Normal file
1
api/app/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Configuration module for application settings."""
|
||||
239
api/app/config/default_ontology_config.py
Normal file
239
api/app/config/default_ontology_config.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""默认本体场景配置
|
||||
|
||||
本模块定义系统预设的本体场景和实体类型配置。
|
||||
这些配置用于在工作空间创建时自动初始化默认场景。
|
||||
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
|
||||
"""
|
||||
|
||||
# 在线教育场景配置
|
||||
ONLINE_EDUCATION_SCENE = {
|
||||
"name_chinese": "在线教育",
|
||||
"name_english": "Online Education",
|
||||
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
|
||||
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "学生",
|
||||
"name_english": "Student",
|
||||
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
|
||||
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教师",
|
||||
"name_english": "Teacher",
|
||||
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
|
||||
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课程",
|
||||
"name_english": "Course",
|
||||
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
|
||||
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
|
||||
},
|
||||
{
|
||||
"name_chinese": "作业",
|
||||
"name_english": "Assignment",
|
||||
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
|
||||
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成绩",
|
||||
"name_english": "Grade",
|
||||
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
|
||||
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "考试",
|
||||
"name_english": "Exam",
|
||||
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
|
||||
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教室",
|
||||
"name_english": "Classroom",
|
||||
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
|
||||
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学科",
|
||||
"name_english": "Subject",
|
||||
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
|
||||
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教材",
|
||||
"name_english": "Textbook",
|
||||
"description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性",
|
||||
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
|
||||
},
|
||||
{
|
||||
"name_chinese": "班级",
|
||||
"name_english": "Class",
|
||||
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
|
||||
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学期",
|
||||
"name_english": "Semester",
|
||||
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
|
||||
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课时",
|
||||
"name_english": "Class Hour",
|
||||
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
|
||||
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教学计划",
|
||||
"name_english": "Teaching Plan",
|
||||
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
|
||||
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 情感陪伴场景配置
|
||||
EMOTIONAL_COMPANION_SCENE = {
|
||||
"name_chinese": "情感陪伴",
|
||||
"name_english": "Emotional Companion",
|
||||
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
|
||||
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "用户",
|
||||
"name_english": "User",
|
||||
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
|
||||
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
|
||||
},
|
||||
{
|
||||
"name_chinese": "情绪",
|
||||
"name_english": "Emotion",
|
||||
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
|
||||
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
|
||||
},
|
||||
{
|
||||
"name_chinese": "活动",
|
||||
"name_english": "Activity",
|
||||
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
|
||||
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
|
||||
},
|
||||
{
|
||||
"name_chinese": "对话",
|
||||
"name_english": "Conversation",
|
||||
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
|
||||
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
|
||||
},
|
||||
{
|
||||
"name_chinese": "兴趣爱好",
|
||||
"name_english": "Hobby",
|
||||
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
|
||||
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "日常事件",
|
||||
"name_english": "Daily Event",
|
||||
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
|
||||
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "关系",
|
||||
"name_english": "Relationship",
|
||||
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
|
||||
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "回忆",
|
||||
"name_english": "Memory",
|
||||
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
|
||||
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "地点",
|
||||
"name_english": "Location",
|
||||
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
|
||||
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
|
||||
},
|
||||
{
|
||||
"name_chinese": "时间节点",
|
||||
"name_english": "Time Point",
|
||||
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
|
||||
"description_english": "Important time markers, including attributes such as date, event, and significance"
|
||||
},
|
||||
{
|
||||
"name_chinese": "目标",
|
||||
"name_english": "Goal",
|
||||
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
|
||||
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成就",
|
||||
"name_english": "Achievement",
|
||||
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
|
||||
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 导出默认场景列表
|
||||
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
|
||||
|
||||
|
||||
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景名称(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景名称
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("name_english", scene_config.get("name_chinese"))
|
||||
return scene_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景描述(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景描述
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("description_english", scene_config.get("description_chinese"))
|
||||
return scene_config.get("description_chinese")
|
||||
|
||||
|
||||
def get_type_name(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型名称(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型名称
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("name_english", type_config.get("name_chinese"))
|
||||
return type_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_type_description(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型描述(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型描述
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("description_english", type_config.get("description_chinese"))
|
||||
return type_config.get("description_chinese")
|
||||
249
api/app/config/default_ontology_initializer.py
Normal file
249
api/app/config/default_ontology_initializer.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""默认本体场景初始化器
|
||||
|
||||
本模块提供默认本体场景和类型的自动初始化功能。
|
||||
在工作空间创建时,自动添加预设的本体场景和实体类型。
|
||||
|
||||
Classes:
|
||||
DefaultOntologyInitializer: 默认本体场景初始化器
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_config import (
|
||||
DEFAULT_SCENES,
|
||||
get_scene_name,
|
||||
get_scene_description,
|
||||
get_type_name,
|
||||
get_type_description,
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
class DefaultOntologyInitializer:
|
||||
"""默认本体场景初始化器
|
||||
|
||||
负责在工作空间创建时自动初始化默认的本体场景和类型。
|
||||
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
|
||||
|
||||
Attributes:
|
||||
db: 数据库会话
|
||||
scene_repo: 场景Repository
|
||||
class_repo: 类型Repository
|
||||
logger: 业务日志记录器
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.scene_repo = OntologySceneRepository(db)
|
||||
self.class_repo = OntologyClassRepository(db)
|
||||
self.logger = get_business_logger()
|
||||
|
||||
def initialize_default_scenes(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
language: str = "zh"
|
||||
) -> Tuple[bool, str]:
|
||||
"""为工作空间初始化默认场景
|
||||
|
||||
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
|
||||
如果创建失败,记录错误日志但不抛出异常。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
language: 语言类型 ("zh" 或 "en"),默认为 "zh"
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
self.logger.info(
|
||||
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
scenes_created = 0
|
||||
total_types_created = 0
|
||||
|
||||
# 遍历默认场景配置
|
||||
for scene_config in DEFAULT_SCENES:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
|
||||
# 创建场景及其类型
|
||||
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
|
||||
|
||||
if scene_id:
|
||||
scenes_created += 1
|
||||
# 统计类型数量
|
||||
types_count = len(scene_config.get("types", []))
|
||||
total_types_created += types_count
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene_id}, types_count={types_count}, language={language}"
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
# 记录总体结果
|
||||
self.logger.info(
|
||||
f"默认场景初始化完成 - workspace_id={workspace_id}, "
|
||||
f"language={language}, scenes_created={scenes_created}, "
|
||||
f"total_types_created={total_types_created}"
|
||||
)
|
||||
|
||||
# 如果至少创建了一个场景,视为成功
|
||||
if scenes_created > 0:
|
||||
return True, ""
|
||||
else:
|
||||
error_msg = "所有默认场景创建失败"
|
||||
self.logger.error(
|
||||
f"默认场景初始化失败 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={error_msg}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"默认场景初始化异常: {str(e)}"
|
||||
self.logger.error(
|
||||
f"默认场景初始化异常 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
def _create_scene_with_types(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
scene_config: dict,
|
||||
language: str = "zh"
|
||||
) -> Optional[UUID]:
|
||||
"""创建场景及其类型
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
Optional[UUID]: 创建的场景ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
scene_description = get_scene_description(scene_config, language)
|
||||
|
||||
# 检查是否已存在同名场景(支持向后兼容)
|
||||
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
|
||||
if existing_scene:
|
||||
self.logger.info(
|
||||
f"场景已存在,跳过创建 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
|
||||
f"language={language}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 创建场景记录,设置 is_system_default=true
|
||||
scene_data = {
|
||||
"scene_name": scene_name,
|
||||
"scene_description": scene_description
|
||||
}
|
||||
|
||||
scene = self.scene_repo.create(scene_data, workspace_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
scene.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
# 批量创建类型
|
||||
types_config = scene_config.get("types", [])
|
||||
types_created = self._batch_create_types(scene.scene_id, types_config, language)
|
||||
|
||||
self.logger.info(
|
||||
f"场景类型创建完成 - scene_id={scene.scene_id}, "
|
||||
f"types_created={types_created}/{len(types_config)}, language={language}"
|
||||
)
|
||||
|
||||
return scene.scene_id
|
||||
|
||||
except Exception as e:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
self.logger.error(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
def _batch_create_types(
|
||||
self,
|
||||
scene_id: UUID,
|
||||
types_config: List[dict],
|
||||
language: str = "zh"
|
||||
) -> int:
|
||||
"""批量创建实体类型
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
types_config: 类型配置列表
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
int: 成功创建的类型数量
|
||||
"""
|
||||
created_count = 0
|
||||
|
||||
for type_config in types_config:
|
||||
try:
|
||||
type_name = get_type_name(type_config, language)
|
||||
type_description = get_type_description(type_config, language)
|
||||
|
||||
# 创建类型数据
|
||||
class_data = {
|
||||
"class_name": type_name,
|
||||
"class_description": type_description
|
||||
}
|
||||
|
||||
# 创建类型
|
||||
ontology_class = self.class_repo.create(class_data, scene_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
ontology_class.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
created_count += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"类型创建成功 - class_name={type_name}, "
|
||||
f"class_id={ontology_class.class_id}, "
|
||||
f"scene_id={scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
type_name = get_type_name(type_config, language)
|
||||
self.logger.warning(
|
||||
f"单个类型创建失败,继续创建其他类型 - "
|
||||
f"class_name={type_name}, scene_id={scene_id}, "
|
||||
f"language={language}, error={str(e)}"
|
||||
)
|
||||
# 继续创建其他类型
|
||||
continue
|
||||
|
||||
return created_count
|
||||
@@ -8,14 +8,17 @@ from fastapi import APIRouter
|
||||
from . import (
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
app_log_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
end_user_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
@@ -68,6 +71,7 @@ manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
manager_router.include_router(app_controller.router)
|
||||
manager_router.include_router(app_log_controller.router)
|
||||
manager_router.include_router(upload_controller.router)
|
||||
manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
@@ -94,5 +98,7 @@ manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(end_user_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import uuid
|
||||
import io
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -17,12 +20,14 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -48,6 +53,7 @@ def list_apps(
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
@@ -59,16 +65,42 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||
if search:
|
||||
search = search.strip()
|
||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||
if len(search) >= 10:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == search,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if matched_id:
|
||||
# 找到 API Key,直接返回关联的应用
|
||||
ids = str(matched_id)
|
||||
|
||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
if app_ids:
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
# 返回标准分页格式
|
||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
# ids 为空时,返回空列表
|
||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||
return success(data=PageData(page=meta, items=[]))
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
@@ -79,6 +111,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
@@ -88,6 +121,37 @@ def list_apps(
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_my_shared_out(
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_all_apps_to_workspace(
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Cancel all app shares from current workspace to a target workspace."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.unshare_all_apps_to_workspace(
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
@@ -156,6 +220,7 @@ def delete_app(
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
payload: app_schema.CopyAppRequest = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -167,6 +232,8 @@ def copy_app(
|
||||
- 不影响原应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# body takes precedence over query param for backward compatibility
|
||||
new_name = (payload.new_name if payload else None) or new_name
|
||||
logger.info(
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
@@ -216,6 +283,27 @@ def get_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_opening(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
statement=opening.get("statement"),
|
||||
suggested_questions=opening.get("suggested_questions", []),
|
||||
))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
@@ -297,7 +385,8 @@ def share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
permission=payload.permission
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
@@ -328,6 +417,32 @@ def unshare_app(
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||
@cur_workspace_access_guard()
|
||||
def update_share_permission(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
payload: app_schema.UpdateSharePermissionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""更新共享权限(readonly <-> editable)
|
||||
|
||||
- 只能修改自己工作空间应用的共享权限
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
share = service.update_share_permission(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
permission=payload.permission,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(data=app_schema.AppShare.model_validate(share))
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
@@ -351,6 +466,46 @@ def list_app_shares(
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id=source_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||
|
||||
|
||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_shared_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""被共享者从自己的工作空间移除共享应用
|
||||
|
||||
- 不会删除源应用,只删除共享记录
|
||||
- 只能移除共享给自己工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.remove_shared_app(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="已移除共享应用")
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
@@ -391,13 +546,13 @@ async def draft_run(
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
service = AppService(db)
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
@@ -408,11 +563,12 @@ async def draft_run(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -429,18 +585,29 @@ async def draft_run(
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
agent_cfg = service._agent_config_from_release(release)
|
||||
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||
else:
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
@@ -482,8 +649,8 @@ async def draft_run(
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -596,7 +763,17 @@ async def draft_run(
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
config = service._workflow_config_from_release(release)
|
||||
else:
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
@@ -739,6 +916,16 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -784,22 +971,29 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||
features_config: dict = agent_cfg.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -818,22 +1012,23 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -879,6 +1074,60 @@ async def update_workflow_config(
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow/export")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""导出工作流配置为YAML文件"""
|
||||
workflow_service = WorkflowService(db)
|
||||
|
||||
return success(data={
|
||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/workflow/import")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_workflow_config(
|
||||
file: UploadFile = File(...),
|
||||
platform: str = Form(...),
|
||||
app_id: str = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
):
|
||||
"""从YAML内容导入工作流配置"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw_text = (await file.read()).decode("utf-8")
|
||||
import_service = WorkflowImportService(db)
|
||||
config = yaml.safe_load(raw_text)
|
||||
result = await import_service.upload_config(platform, config)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
import_service = WorkflowImportService(db)
|
||||
app = await import_service.save_workflow(
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
temp_id=data.temp_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
@@ -889,12 +1138,14 @@ def get_app_statistics(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
@@ -931,6 +1182,8 @@ def get_workspace_api_statistics(
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
@@ -949,3 +1202,57 @@ def get_workspace_api_statistics(
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
release_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||
"""
|
||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||
encoded = quote(filename, safe=".")
|
||||
yaml_bytes = yaml_str.encode("utf-8")
|
||||
file_stream = io.BytesIO(yaml_bytes)
|
||||
file_stream.seek(0)
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type="application/octet-stream; charset=utf-8",
|
||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||
"Content-Length": str(len(yaml_bytes))}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw = (await file.read()).decode("utf-8")
|
||||
dsl = yaml.safe_load(raw)
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
89
api/app/controllers/app_log_controller.py
Normal file
89
api/app/controllers/app_log_controller.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""应用日志(消息记录)接口"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_log_service import AppLogService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_logs(
|
||||
app_id: uuid.UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
is_draft: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看应用下所有会话记录(分页)
|
||||
|
||||
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||
- 按最新更新时间倒序排列
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversations, total = log_service.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_log_detail(
|
||||
app_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看某会话的完整消息记录
|
||||
|
||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||
- 消息 meta_data 包含模型名、token 用量等信息
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversation = log_service.get_conversation_detail(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
|
||||
return success(data=detail)
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,10 +43,10 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
@@ -69,7 +72,7 @@ async def login_for_access_token(
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -82,7 +85,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -110,14 +113,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -125,18 +129,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -167,7 +171,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -175,14 +179,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -192,5 +197,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
@@ -441,14 +441,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
|
||||
@@ -208,14 +208,64 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期,自动触发生成
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
try:
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
except (ValueError, KeyError) as gen_e:
|
||||
# 预期内的业务异常:配置缺失、数据格式问题等
|
||||
api_logger.warning(
|
||||
f"自动生成建议失败(业务异常): {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
f"自动生成建议失败: {str(gen_e)}",
|
||||
""
|
||||
)
|
||||
except Exception as gen_e:
|
||||
# 非预期异常:记录完整 traceback 便于排查
|
||||
api_logger.error(
|
||||
f"自动生成建议时发生未预期异常: {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成建议时发生内部错误: {str(gen_e)}"
|
||||
)
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
48
api/app/controllers/end_user_controller.py
Normal file
48
api/app/controllers/end_user_controller.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""End User 管理接口 - 无需认证"""
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_end_user(
|
||||
data: CreateEndUserRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the given workspace.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
"""
|
||||
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=None,
|
||||
workspace_id=data.workspace_id,
|
||||
other_id=data.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -14,8 +14,11 @@ Routes:
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
import httpx
|
||||
import mimetypes
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -47,6 +50,19 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _match_scheme(request: Request, url: str) -> str:
|
||||
"""
|
||||
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||
"""
|
||||
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||
if url.startswith("http://") and incoming_scheme == "https":
|
||||
return "https://" + url[7:]
|
||||
if url.startswith("https://") and incoming_scheme == "http":
|
||||
return "http://" + url[8:]
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
@@ -78,7 +94,7 @@ async def upload_file(
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
@@ -159,7 +175,6 @@ async def upload_file_with_share_token(
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
@@ -278,8 +293,104 @@ async def upload_file_with_share_token(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||
async def get_file_info_by_url(
|
||||
url: str,
|
||||
):
|
||||
"""
|
||||
Get file information by network URL (no authentication required).
|
||||
|
||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||
Falls back to GET request if HEAD is not supported.
|
||||
Returns file type, name, and size.
|
||||
|
||||
Args:
|
||||
url: The network URL of the file.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file information.
|
||||
"""
|
||||
api_logger.info(f"File info by URL request: url={url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try HEAD request first
|
||||
response = await client.head(url, follow_redirects=True)
|
||||
|
||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||
if response.status_code != 200:
|
||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
# Get file size from Content-Length header or actual content
|
||||
file_size = response.headers.get("Content-Length")
|
||||
if file_size:
|
||||
file_size = int(file_size)
|
||||
elif hasattr(response, 'content'):
|
||||
file_size = len(response.content)
|
||||
else:
|
||||
file_size = None
|
||||
|
||||
# Get content type from Content-Type header
|
||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||
# Remove charset and other parameters from content type
|
||||
content_type = content_type.split(';')[0].strip()
|
||||
|
||||
# Extract filename from Content-Disposition or URL
|
||||
file_name = None
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition and "filename=" in content_disposition:
|
||||
parts = content_disposition.split("filename=")
|
||||
if len(parts) > 1:
|
||||
file_name = parts[1].strip('"').strip("'")
|
||||
|
||||
if not file_name:
|
||||
parsed_url = urlparse(url)
|
||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||
|
||||
# Extract file extension from filename
|
||||
_, file_ext = os.path.splitext(file_name)
|
||||
|
||||
# If no extension found, infer from content type
|
||||
if not file_ext:
|
||||
ext = mimetypes.guess_extension(content_type)
|
||||
if ext:
|
||||
file_ext = ext
|
||||
file_name = f"{file_name}{file_ext}"
|
||||
|
||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"file_name": file_name,
|
||||
"file_ext": file_ext.lower() if file_ext else "",
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
},
|
||||
msg="File information retrieved successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file information: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -327,6 +438,7 @@ async def download_file(
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
@@ -400,6 +512,7 @@ async def delete_file(
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
@@ -461,8 +574,13 @@ async def get_file_url(
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||
url = await storage_service.get_file_url(
|
||||
file_key,
|
||||
expires=expires,
|
||||
file_name=file_metadata.file_name,
|
||||
)
|
||||
url = _match_scheme(request, url)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
@@ -482,8 +600,54 @@ async def get_file_url(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||
async def get_permanent_file_url(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
获取文件的永久公开 URL(无过期时间)。
|
||||
|
||||
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||
"""
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if isinstance(storage, LocalStorage):
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||
else:
|
||||
url = await storage.get_permanent_url(file_key)
|
||||
if not url:
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Permanent URL not supported for current storage backend")
|
||||
|
||||
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||
return success(
|
||||
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
@@ -555,6 +719,7 @@ async def public_download_file(
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -566,6 +731,7 @@ async def public_download_file(
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
@@ -624,7 +790,8 @@ async def permanent_download_file(
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -632,3 +799,44 @@ async def permanent_download_file(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||
async def get_file_status(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get file upload/processing status (no authentication required).
|
||||
|
||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||
Returns status: pending, completed, or failed.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file status and metadata.
|
||||
"""
|
||||
api_logger.info(f"File status request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"file_id": str(file_id),
|
||||
"status": file_metadata.status,
|
||||
"file_name": file_metadata.file_name,
|
||||
"file_size": file_metadata.file_size,
|
||||
"content_type": file_metadata.content_type,
|
||||
},
|
||||
msg="File status retrieved successfully"
|
||||
)
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
|
||||
def get_languages(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of all supported languages.
|
||||
|
||||
Returns:
|
||||
List of language information including code, name, and status
|
||||
"""
|
||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Get available locales from translation service
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Build language info list
|
||||
languages = []
|
||||
for locale in available_locales:
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
# Get native names
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
languages.append(language_info)
|
||||
|
||||
response = LanguageListResponse(languages=languages)
|
||||
|
||||
api_logger.info(f"Returning {len(languages)} languages")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||
def get_language(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get information about a specific language.
|
||||
|
||||
Args:
|
||||
locale: Language code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Language information
|
||||
"""
|
||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Build language info
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
api_logger.info(f"Returning language info for: {locale}")
|
||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
|
||||
def add_language(
|
||||
request: LanguageCreateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Add a new language (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual language addition
|
||||
requires creating translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
request: Language creation request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language already exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if request.code in available_locales:
|
||||
api_logger.warning(f"Language already exists: {request.code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.language.already_exists", locale=request.code)
|
||||
)
|
||||
|
||||
# Note: Actual language addition requires creating translation files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language addition validated: {request.code}. "
|
||||
"Translation files need to be created manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t(
|
||||
"i18n.language.add_instructions",
|
||||
locale=request.code,
|
||||
dir=settings.I18N_CORE_LOCALES_DIR
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||
def update_language(
|
||||
locale: str,
|
||||
request: LanguageUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update language configuration (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual configuration
|
||||
changes require updating environment variables or config files.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
request: Language update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Note: Actual configuration changes require updating settings
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language update validated: {locale}. "
|
||||
"Configuration changes require environment variable updates."
|
||||
)
|
||||
|
||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
|
||||
def get_all_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for all or specific locale.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
All translations organized by locale and namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
if locale:
|
||||
# Get translations for specific locale
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = {
|
||||
locale: translation_service._cache.get(locale, {})
|
||||
}
|
||||
else:
|
||||
# Get all translations
|
||||
translations = translation_service._cache
|
||||
|
||||
response = TranslationResponse(translations=translations)
|
||||
|
||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||
def get_locale_translations(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
|
||||
Returns:
|
||||
All translations for the locale organized by namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = translation_service._cache.get(locale, {})
|
||||
|
||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||
def get_namespace_translations(
|
||||
locale: str,
|
||||
namespace: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get translations for a specific namespace in a locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||
|
||||
Returns:
|
||||
Translations for the specified namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get namespace translations request: locale={locale}, "
|
||||
f"namespace={namespace}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Get namespace translations
|
||||
locale_translations = translation_service._cache.get(locale, {})
|
||||
namespace_translations = locale_translations.get(namespace, {})
|
||||
|
||||
if not namespace_translations:
|
||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||
)
|
||||
return success(
|
||||
data={
|
||||
"locale": locale,
|
||||
"namespace": namespace,
|
||||
"translations": namespace_translations
|
||||
},
|
||||
msg=t("common.success.retrieved")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||
def update_translation(
|
||||
locale: str,
|
||||
key: str,
|
||||
request: TranslationUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update a single translation (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual translation updates
|
||||
require modifying translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
request: Translation update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update translation request: locale={locale}, key={key}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Validate key format
|
||||
if "." not in key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||
)
|
||||
|
||||
# Note: Actual translation updates require modifying JSON files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Translation update validated: {locale}/{key}. "
|
||||
"Translation files need to be updated manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/translations/missing", response_model=ApiResponse)
|
||||
def get_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of missing translations.
|
||||
|
||||
Compares translations across locales to find missing keys.
|
||||
|
||||
Args:
|
||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||
|
||||
Returns:
|
||||
List of missing translation keys
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Get default locale translations as reference
|
||||
default_translations = translation_service._cache.get(default_locale, {})
|
||||
|
||||
# Collect all keys from default locale
|
||||
def collect_keys(data, prefix=""):
|
||||
keys = []
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, dict):
|
||||
keys.extend(collect_keys(value, full_key))
|
||||
else:
|
||||
keys.append(full_key)
|
||||
return keys
|
||||
|
||||
default_keys = set()
|
||||
for namespace, translations in default_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
default_keys.update(namespace_keys)
|
||||
|
||||
# Find missing keys in target locale(s)
|
||||
missing_by_locale = {}
|
||||
|
||||
target_locales = [locale] if locale else [
|
||||
loc for loc in available_locales if loc != default_locale
|
||||
]
|
||||
|
||||
for target_locale in target_locales:
|
||||
if target_locale not in available_locales:
|
||||
continue
|
||||
|
||||
target_translations = translation_service._cache.get(target_locale, {})
|
||||
target_keys = set()
|
||||
|
||||
for namespace, translations in target_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
target_keys.update(namespace_keys)
|
||||
|
||||
missing_keys = default_keys - target_keys
|
||||
if missing_keys:
|
||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||
|
||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||
|
||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
|
||||
def reload_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Trigger hot reload of translation files (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to reload (defaults to reloading all locales)
|
||||
|
||||
Returns:
|
||||
Reload status and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Reload translations request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||
api_logger.warning("Hot reload is disabled in configuration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=t("i18n.reload.disabled")
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
try:
|
||||
# Reload translations
|
||||
translation_service.reload(locale)
|
||||
|
||||
# Get statistics
|
||||
available_locales = translation_service.get_available_locales()
|
||||
reloaded_locales = [locale] if locale else available_locales
|
||||
|
||||
response = ReloadResponse(
|
||||
success=True,
|
||||
reloaded_locales=reloaded_locales,
|
||||
total_locales=len(available_locales)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||
)
|
||||
|
||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=t("i18n.reload.failed", error=str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
|
||||
def get_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get i18n performance metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Performance metrics including:
|
||||
- Request counts
|
||||
- Missing translations
|
||||
- Timing statistics
|
||||
- Locale usage
|
||||
- Error counts
|
||||
"""
|
||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
metrics = translation_service.get_metrics_summary()
|
||||
|
||||
api_logger.info("Returning i18n metrics")
|
||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||
def get_cache_stats(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get cache statistics (admin only).
|
||||
|
||||
Returns:
|
||||
Cache statistics including:
|
||||
- Hit/miss rates
|
||||
- LRU cache performance
|
||||
- Loaded locales
|
||||
- Memory usage
|
||||
"""
|
||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
cache_stats = translation_service.get_cache_stats()
|
||||
memory_usage = translation_service.get_memory_usage()
|
||||
|
||||
data = {
|
||||
"cache": cache_stats,
|
||||
"memory": memory_usage
|
||||
}
|
||||
|
||||
api_logger.info("Returning cache statistics")
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
|
||||
def get_prometheus_metrics(
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get metrics in Prometheus format (admin only).
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as plain text
|
||||
"""
|
||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
prometheus_output = metrics.export_prometheus()
|
||||
|
||||
from fastapi.responses import PlainTextResponse
|
||||
return PlainTextResponse(content=prometheus_output)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||
def reset_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Reset all metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
metrics.reset()
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_service.cache.reset_stats()
|
||||
|
||||
api_logger.info("Metrics reset completed")
|
||||
return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
|
||||
def get_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get missing translation logs (admin only).
|
||||
|
||||
Returns logged missing translations with context information.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
limit: Maximum number of entries to return (default: 100)
|
||||
|
||||
Returns:
|
||||
Missing translation logs with context
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translation logs request: locale={locale}, "
|
||||
f"limit={limit}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Get missing translations
|
||||
missing_translations = translation_logger.get_missing_translations(locale)
|
||||
|
||||
# Get missing with context
|
||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||
|
||||
# Get statistics
|
||||
statistics = translation_logger.get_statistics()
|
||||
|
||||
data = {
|
||||
"missing_translations": missing_translations,
|
||||
"recent_context": missing_with_context,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"Returning {statistics['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||
def generate_missing_translation_report(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Generate a comprehensive missing translation report (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Comprehensive report with missing translations and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Generate missing translation report request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate report
|
||||
report = translation_logger.generate_report(locale)
|
||||
|
||||
api_logger.info(
|
||||
f"Generated report with {report['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=report, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||
def export_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Export missing translations to JSON file (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Export status and file path
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Export missing translations request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
locale_suffix = f"_{locale}" if locale else "_all"
|
||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||
|
||||
# Export to file
|
||||
translation_logger.export_to_json(output_file)
|
||||
|
||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||
return success(
|
||||
data={"file_path": output_file},
|
||||
msg=t("i18n.logs.export_success", file=output_file)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||
def clear_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Clear missing translation logs (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to clear (clears all if not specified)
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Clear logs
|
||||
translation_logger.clear(locale)
|
||||
|
||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||
return success(msg=t("i18n.logs.clear_success"))
|
||||
@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
@@ -159,12 +201,8 @@ async def get_preference_tags(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
@@ -330,12 +360,8 @@ async def get_behavior_habits(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
from app.services import mcp_market_config_service, mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -55,6 +55,12 @@ async def get_mcp_servers(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
if page * pagesize > 100:
|
||||
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
@@ -64,14 +70,16 @@ async def get_mcp_servers(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
@@ -83,14 +91,16 @@ async def get_mcp_servers(
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
headers=api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
@@ -115,9 +125,82 @@ async def get_mcp_servers(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
# 5. Update mck_market.mcp_count
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
db_mcp_market.mcp_count = total
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
||||
async def get_operational_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the operational mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + operational mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Execute paged query
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 3. Return structured response
|
||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
@@ -139,14 +222,16 @@ async def get_mcp_server(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
@@ -167,7 +252,28 @@ async def create_mcp_market_config(
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
# 1. Validate token can access ModelScope MCP market
|
||||
if not create_data.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token is required to access ModelScope MCP market"
|
||||
)
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
# 2. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
@@ -175,6 +281,32 @@ async def create_mcp_market_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
# 2. verify token
|
||||
create_data.status = 1
|
||||
try:
|
||||
api = MCPApi()
|
||||
token = create_data.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': 1,
|
||||
'page_size': 20,
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
create_data.status = 0
|
||||
# 3. create mcp_market_config
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
@@ -203,10 +335,7 @@ async def get_mcp_market_config(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -236,10 +365,7 @@ async def get_mcp_market_config_by_mcp_market_id(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -265,12 +391,27 @@ async def update_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
# 2. Validate new token if provided
|
||||
if update_data.token is not None:
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
@@ -285,7 +426,7 @@ async def update_mcp_market_config(
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
@@ -298,7 +439,7 @@ async def update_mcp_market_config(
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
# 5. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
@@ -322,10 +463,7 @@ async def delete_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
@@ -1,27 +1,29 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
@@ -36,7 +38,7 @@ router = APIRouter(
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get latest health status written by Celery periodic task
|
||||
@@ -54,8 +56,9 @@ async def get_health_status(
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download or stream agent service log file
|
||||
@@ -74,16 +77,16 @@ async def download_log(
|
||||
- transmission mode: StreamingResponse with SSE
|
||||
"""
|
||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||
|
||||
|
||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||
if log_type not in ["file", "transmission"]:
|
||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
"log_type必须是'file'或'transmission'"
|
||||
)
|
||||
|
||||
|
||||
# Route to appropriate mode
|
||||
if log_type == "file":
|
||||
# File mode: Return complete log file content
|
||||
@@ -115,147 +118,150 @@ async def download_log(
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
# @router.post("/writer_service", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Write service endpoint - processes write operations synchronously
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Response with write operation status
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
#
|
||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
# if storage_type == 'rag':
|
||||
# if workspace_id:
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge:
|
||||
# user_rag_memory_id = str(knowledge.id)
|
||||
# else:
|
||||
# api_logger.warning(
|
||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
# else:
|
||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
#
|
||||
# api_logger.info(
|
||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
# try:
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
# result = await memory_agent_service.write_memory(
|
||||
# user_input.end_user_id,
|
||||
# messages_list,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id,
|
||||
# language
|
||||
# )
|
||||
#
|
||||
# return success(data=result, msg="写入成功")
|
||||
# except BaseException as e:
|
||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
# if hasattr(e, 'exceptions'):
|
||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
# detailed_error = "; ".join(error_messages)
|
||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
#
|
||||
#
|
||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server_async(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Async write service endpoint - enqueues write processing to Celery
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Task ID for tracking async operation
|
||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(
|
||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
# if workspace_id:
|
||||
#
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# try:
|
||||
# # 获取标准化的消息列表
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
#
|
||||
# task = celery_app.send_task(
|
||||
# "app.core.memory.agent.write_message",
|
||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
# )
|
||||
# api_logger.info(f"Write task queued: {task.id}")
|
||||
#
|
||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Read service endpoint - processes read operations synchronously
|
||||
@@ -290,8 +296,9 @@ async def read_server(
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
|
||||
api_logger.info(
|
||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.end_user_id,
|
||||
@@ -305,7 +312,8 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||
user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
@@ -318,7 +326,7 @@ async def read_server(
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
result['answer'] = retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -334,9 +342,10 @@ async def read_server(
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id:str = Form(..., description="模型ID"),
|
||||
model_id: str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
@@ -349,9 +358,6 @@ async def file_update(
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
@@ -360,7 +366,7 @@ async def file_update(
|
||||
for file in files:
|
||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||
content = await file.read()
|
||||
|
||||
|
||||
if file.content_type and file.content_type.startswith("image/"):
|
||||
vision_model = QWenCV(
|
||||
key=apiConfig.api_key,
|
||||
@@ -374,12 +380,12 @@ async def file_update(
|
||||
else:
|
||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||
|
||||
|
||||
result_text = ';'.join(file_content)
|
||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||
|
||||
|
||||
return success(data=result_text, msg="转换文本成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||
@@ -429,8 +435,8 @@ async def read_server_async(
|
||||
|
||||
@router.get("/read_result/", response_model=ApiResponse)
|
||||
async def get_read_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async read task
|
||||
@@ -451,7 +457,7 @@ async def get_read_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_read_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -469,7 +475,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="查询任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -478,7 +484,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -498,7 +504,7 @@ async def get_read_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -506,8 +512,8 @@ async def get_read_task_result(
|
||||
|
||||
@router.get("/write_result/", response_model=ApiResponse)
|
||||
async def get_write_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async write task
|
||||
@@ -528,7 +534,7 @@ async def get_write_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_write_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -546,7 +552,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="写入任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -555,7 +561,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -575,7 +581,7 @@ async def get_write_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -583,9 +589,9 @@ async def get_write_task_result(
|
||||
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
@@ -628,26 +634,21 @@ async def status_type(
|
||||
|
||||
@router.get("/stats/types", response_model=ApiResponse)
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
api_logger.info(
|
||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
@@ -655,48 +656,70 @@ async def get_knowledge_type_stats_api(
|
||||
current_workspace_id=current_user.current_workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="获取知识库类型统计成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -734,17 +757,17 @@ async def get_user_profile_api(
|
||||
# ):
|
||||
# """
|
||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
|
||||
# Returns:
|
||||
# Parsed API documentation including title, meta info, and sections
|
||||
# """
|
||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||
# try:
|
||||
# result = await memory_agent_service.get_api_docs(file_path)
|
||||
|
||||
|
||||
# if result.get("success"):
|
||||
# return success(msg=result["msg"], data=result["data"])
|
||||
# else:
|
||||
@@ -760,9 +783,9 @@ async def get_user_profile_api(
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
@@ -781,9 +804,9 @@ async def get_end_user_connected_config(
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
@@ -792,4 +815,4 @@ async def get_end_user_connected_config(
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
@@ -9,6 +10,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -148,6 +150,21 @@ async def get_workspace_end_users(
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_interest_distribution_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
@@ -176,7 +193,15 @@ async def get_workspace_end_users(
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -386,14 +411,15 @@ def get_current_user_rag_total_num(
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
获取当前宿主知识库中的chunk内容(分页)
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@@ -406,26 +432,18 @@ async def get_chunk_summary_tag(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
"summary": "用户摘要",
|
||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||
"personas": ["产品设计师", ...],
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -433,9 +451,8 @@ async def get_chunk_summary_tag(
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
@@ -446,29 +463,64 @@ async def get_chunk_insight(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
读取RAG洞察报告(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
"insight": "总体概述",
|
||||
"behavior_pattern": "行为模式",
|
||||
"key_findings": "关键发现",
|
||||
"growth_trajectory": "成长轨迹",
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
class GenerateRagProfileRequest(BaseModel):
|
||||
end_user_id: str = Field(..., description="宿主ID")
|
||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||
max_tags: int = Field(10, description="最大标签数量")
|
||||
|
||||
|
||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||
async def generate_rag_profile(
|
||||
body: GenerateRagProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||
每次请求都会重新生成,覆盖已有数据。
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||
|
||||
data = await memory_dashboard_service.generate_rag_profile(
|
||||
end_user_id=body.end_user_id,
|
||||
limit=body.limit,
|
||||
max_tags=body.max_tags,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
api_logger.info(f"RAG画像生产完成: {data}")
|
||||
return success(data=data, msg="RAG画像生产完成")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -503,6 +555,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -541,9 +602,12 @@ async def dashboard_data(
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
neo4j_data["total_app"] = len(apps_orm)
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
neo4j_data["total_app"] = total_app
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
@@ -563,17 +627,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -589,23 +658,39 @@ async def dashboard_data(
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
rag_data["total_app"] = len(apps_orm)
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
rag_data["total_app"] = total_app
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
||||
)
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
# 跳过没有配置的应用
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
@@ -204,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,19 +1,40 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -25,24 +46,73 @@ async def short_term_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
@@ -75,23 +75,19 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
@@ -99,17 +95,43 @@ def create_config(
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||
config_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
except Exception as e:
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
@@ -122,24 +144,24 @@ def delete_config(
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
@@ -149,7 +171,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
@@ -163,7 +185,7 @@ def delete_config(
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
@@ -172,7 +194,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
@@ -180,9 +202,9 @@ def delete_config(
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -190,12 +212,13 @@ def update_config(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||
"config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -208,9 +231,9 @@ def update_config(
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -218,7 +241,7 @@ def update_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -233,11 +256,11 @@ def update_config_extracted(
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -245,7 +268,7 @@ def read_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -255,18 +278,19 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -280,14 +304,14 @@ def read_all_config(
|
||||
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
@@ -310,9 +334,9 @@ async def pilot_run(
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
@@ -321,12 +345,12 @@ async def get_kb_type_distribution(
|
||||
api_logger.error(f"KB type distribution failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
@@ -338,9 +362,9 @@ async def search_dialogues_num(
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
@@ -352,9 +376,9 @@ async def search_chunks_num(
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
@@ -366,9 +390,9 @@ async def search_statements_num(
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
@@ -380,9 +404,9 @@ async def search_entities_num(
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
@@ -394,9 +418,9 @@ async def search_all_num(
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
@@ -408,9 +432,9 @@ async def search_entities_detials(
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
@@ -420,14 +444,12 @@ async def search_entity_edges(
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
@@ -438,18 +460,18 @@ async def get_hot_memory_tags_api(
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
@@ -458,11 +480,11 @@ async def get_hot_memory_tags_api(
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
@@ -472,9 +494,9 @@ async def get_hot_memory_tags_api(
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
@@ -482,8 +504,8 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
@@ -493,12 +515,12 @@ async def clear_hot_memory_tags_cache(
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
@@ -507,12 +529,12 @@ async def clear_hot_memory_tags_cache(
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
@@ -520,13 +542,13 @@ async def clear_hot_memory_tags_cache(
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import conversation_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
@@ -32,35 +33,47 @@ def get_memory_count(
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
end_user_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
Retrieve conversations for the current user in a specific group with pagination.
|
||||
|
||||
Args:
|
||||
end_user_id (UUID): The group identifier.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
pagesize (int): Number of items per page. Defaults to 20.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
ApiResponse: Contains a paginated list of conversations.
|
||||
"""
|
||||
page = max(1, page)
|
||||
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
end_user_id
|
||||
conversations, total = conversation_service.get_user_conversations(
|
||||
end_user_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
return success(data={
|
||||
"items": [
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
],
|
||||
"total": total,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": page_size,
|
||||
"total": total,
|
||||
"hasnext": (page * page_size) < total
|
||||
},
|
||||
}, msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
@@ -90,11 +103,7 @@ def get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
conversation_schema.Message.model_validate(message)
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -74,10 +75,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
capability_list = []
|
||||
if capability is not None:
|
||||
flat_capability = []
|
||||
for item in capability:
|
||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||
flat_capability.extend(split_items)
|
||||
|
||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||
capability_list = unique_flat_capability
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
@@ -371,6 +383,11 @@ def update_model(
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
if model_data.is_active:
|
||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||
if not active_keys:
|
||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
@@ -469,7 +486,9 @@ async def create_model_api_key_by_provider(
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
model_config_ids=model_config_ids,
|
||||
capability=api_key_data.capability,
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
|
||||
@@ -25,13 +25,13 @@ from typing import Dict, Optional, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
@@ -123,15 +124,23 @@ def _get_ontology_service(
|
||||
)
|
||||
|
||||
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
if not api_keys:
|
||||
# from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
|
||||
if not api_key_config:
|
||||
logger.error(f"Model {llm_id} has no active API key")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="指定的LLM模型没有可用的API密钥"
|
||||
)
|
||||
api_key_config = api_keys[0]
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
# if not api_keys:
|
||||
# logger.error(f"Model {llm_id} has no active API key")
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail="指定的LLM模型没有可用的API密钥"
|
||||
# )
|
||||
# api_key_config = api_keys[0]
|
||||
|
||||
is_composite = getattr(model_config, 'is_composite', False)
|
||||
logger.info(
|
||||
@@ -153,6 +162,7 @@ def _get_ontology_service(
|
||||
provider=actual_provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
@@ -279,7 +289,8 @@ async def extract_ontology(
|
||||
async def create_scene(
|
||||
request: SceneCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体场景
|
||||
|
||||
@@ -350,8 +361,18 @@ async def create_scene(
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
||||
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||
@@ -399,6 +420,20 @@ async def update_scene(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认场景
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(scene_uuid)
|
||||
if scene and scene.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认场景: user_id={current_user.id}, "
|
||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认场景不可修改",
|
||||
"该场景为系统预设场景,不允许修改"
|
||||
)
|
||||
|
||||
# 创建OntologyService实例
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -491,6 +526,19 @@ async def delete_scene(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认场景
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(scene_uuid)
|
||||
if scene and scene.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认场景: user_id={current_user.id}, "
|
||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
|
||||
)
|
||||
|
||||
# 创建OntologyService实例
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -514,6 +562,9 @@ async def delete_scene(
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="场景删除成功")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
@@ -621,7 +672,8 @@ async def get_scenes(
|
||||
async def create_class(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体类型
|
||||
|
||||
@@ -636,7 +688,7 @@ async def create_class(
|
||||
ApiResponse: 包含创建的类型信息
|
||||
"""
|
||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||
return await create_class_handler(request, db, current_user)
|
||||
return await create_class_handler(request, db, current_user, x_language_type)
|
||||
|
||||
|
||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
@@ -56,7 +58,7 @@ async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
pagesize: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -69,14 +71,14 @@ async def scenes_handler(
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -103,13 +105,13 @@ async def scenes_handler(
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
@@ -117,17 +119,15 @@ async def scenes_handler(
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
if page is not None and pagesize is not None:
|
||||
start_idx = (page - 1) * pagesize
|
||||
end_idx = start_idx + pagesize
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -139,17 +139,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -163,28 +162,25 @@ async def scenes_handler(
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -196,17 +192,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -236,7 +231,8 @@ async def scenes_handler(
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = None
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
@@ -269,8 +265,11 @@ async def create_class_handler(
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
# 单个创建 - 先检查重名
|
||||
class_data = classes_data[0]
|
||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
@@ -328,12 +327,36 @@ async def create_class_handler(
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||
class_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
class_name = request.classes[0].class_name if request.classes else ""
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
@@ -366,6 +389,20 @@ async def update_class_handler(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可修改",
|
||||
"该类型为系统预设类型,不允许修改"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
@@ -429,6 +466,20 @@ async def delete_class_handler(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可删除",
|
||||
"该类型为系统预设类型,不允许删除"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
@@ -585,6 +636,7 @@ async def classes_handler(
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
is_system_default=scene.is_system_default,
|
||||
items=items
|
||||
)
|
||||
|
||||
|
||||
@@ -2,25 +2,33 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
@@ -206,15 +214,16 @@ def list_conversations(
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
@@ -251,8 +260,41 @@ def get_conversation(
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
file_ids = []
|
||||
message_file_id_map = {}
|
||||
|
||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||
for idx, m in enumerate(messages):
|
||||
if m.role == "assistant" and m.meta_data:
|
||||
audio_url = m.meta_data.get("audio_url")
|
||||
if not audio_url:
|
||||
continue
|
||||
try:
|
||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# audio_url 无法解析为 UUID,标记为 unknown
|
||||
m.meta_data["audio_status"] = "unknown"
|
||||
continue
|
||||
|
||||
file_ids.append(file_id)
|
||||
message_file_id_map[idx] = file_id
|
||||
|
||||
# 批量查询所有相关的 FileMetadata
|
||||
file_status_map = {}
|
||||
if file_ids:
|
||||
file_metas = (
|
||||
db.query(FileMetadata)
|
||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||
.all()
|
||||
)
|
||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||
|
||||
# 第二次遍历:将查询结果映射回消息
|
||||
for idx, file_id in message_file_id_map.items():
|
||||
m = messages[idx]
|
||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
@@ -293,40 +335,49 @@ async def chat(
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
share, release = service.get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
original_user_id=user_id
|
||||
)
|
||||
|
||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||
if not new_end_user.memory_config_id:
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||
if memory_config_id:
|
||||
new_end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
db.refresh(new_end_user)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid = share.app_id
|
||||
# appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
# app = db.query(App).filter(
|
||||
# App.id == appid,
|
||||
# App.is_active.is_(True)
|
||||
# ).first()
|
||||
# if not app:
|
||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
# workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
@@ -359,12 +410,12 @@ async def chat(
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
if app_type == AppType.AGENT:
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
@@ -610,11 +661,11 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -638,6 +689,39 @@ async def chat(
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.get("/config", summary="获取应用启动配置")
|
||||
async def config_query(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
share_service = SharedChatService(db)
|
||||
share_token = share_data.share_token
|
||||
share, release = share_service.get_release_by_share_token(share_token, password)
|
||||
if release.app.type == AppType.WORKFLOW:
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config),
|
||||
"memory": workflow_service.is_memory_enable(release.config),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables"),
|
||||
"memory": release.config.get("memory", {}).get("enabled"),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": [],
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
return success(data=content)
|
||||
|
||||
@@ -89,15 +89,14 @@ async def chat(
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
other_id = payload.user_id
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search = True
|
||||
@@ -135,7 +134,8 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=end_user_id,
|
||||
is_draft=False
|
||||
is_draft=False,
|
||||
conversation_id=payload.conversation_id
|
||||
)
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
@@ -249,6 +249,7 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -279,6 +280,7 @@ async def chat(
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryWriteRequest,
|
||||
@@ -31,15 +32,16 @@ async def write_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
||||
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
@@ -62,13 +64,15 @@ async def read_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryReadRequest = Body(..., embed=False),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
@@ -85,3 +89,27 @@ async def read_memory_api_service(
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def list_memory_configs(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs for the workspace.
|
||||
|
||||
Returns all available memory configurations associated with the authorized workspace.
|
||||
"""
|
||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.list_memory_configs(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
)
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
@@ -3,8 +3,11 @@ from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||
CustomToolTestRequest, ToolActiveUpdate
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
@@ -14,6 +17,7 @@ from app.models import User
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
@@ -72,6 +76,8 @@ async def get_tool_methods(
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -97,7 +103,13 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
tool_id = service.create_tool(
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = await service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
@@ -107,8 +119,12 @@ async def create_tool(
|
||||
tags=request.tags
|
||||
)
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except BusinessException as e:
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -137,6 +153,8 @@ async def update_tool(
|
||||
return success(msg="工具更新成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -147,7 +165,7 @@ async def delete_tool(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除,is_active=False)"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
@@ -159,6 +177,32 @@ async def delete_tool(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||
async def set_tool_active(
|
||||
tool_id: str,
|
||||
request: ToolActiveUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""设置工具可用状态(启用/禁用)
|
||||
|
||||
- is_active=true: 启用工具
|
||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||
"""
|
||||
try:
|
||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
action = "启用" if request.is_active else "禁用"
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution/execute", response_model=ApiResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecuteRequest,
|
||||
@@ -187,6 +231,8 @@ async def execute_tool(
|
||||
},
|
||||
msg="工具执行完成"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -216,8 +262,10 @@ async def sync_mcp_tools(
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -240,8 +288,10 @@ async def test_tool_connection(
|
||||
# 普通连接测试
|
||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||
return success(data=result, msg="连接测试完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -42,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -50,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -57,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -74,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -105,7 +111,19 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -113,6 +131,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -125,8 +144,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
@@ -134,6 +152,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -144,7 +163,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -152,6 +171,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -164,7 +184,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -172,6 +192,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -186,16 +207,17 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
@@ -203,8 +225,8 @@ def verify_pwd(
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
@@ -212,6 +234,7 @@ async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
@@ -219,7 +242,7 @@ async def send_email_code(
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
@@ -227,6 +250,7 @@ async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
@@ -239,4 +263,51 @@ async def change_email(
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
|
||||
def get_current_user_language(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户的语言偏好"""
|
||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||
|
||||
language = user_service.get_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||
msg=t("users.language.get_success")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
|
||||
def update_current_user_language(
|
||||
request: user_schema.LanguagePreferenceRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""设置当前用户的语言偏好"""
|
||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||
|
||||
updated_user = user_service.update_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
language=request.language,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||
msg=t("users.language.update_success")
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -17,14 +17,17 @@ from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.end_user_info_schema import (
|
||||
EndUserInfoResponse,
|
||||
EndUserInfoCreate,
|
||||
EndUserInfoUpdate,
|
||||
)
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.dependencies import get_current_user
|
||||
@@ -44,9 +47,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -72,10 +75,10 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
@@ -89,7 +92,7 @@ async def get_user_summary_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -101,7 +104,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -116,10 +119,10 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
@@ -133,7 +136,7 @@ async def generate_cache_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -154,10 +157,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -208,9 +213,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -219,7 +224,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -227,21 +233,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -295,110 +303,165 @@ async def get_graph_data_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
other_name=end_user.other_name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||
|
||||
#=======================终端用户信息接口=======================
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
@router.get("/end_user_info", response_model=ApiResponse)
|
||||
async def get_end_user_info(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
查询终端用户信息记录
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
根据 end_user_id 查询单条终端用户信息记录。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 验证工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 调用 Service 层处理业务逻辑
|
||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="查询成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
||||
|
||||
|
||||
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
||||
async def update_end_user_info(
|
||||
info_update: EndUserInfoUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户信息记录
|
||||
|
||||
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
||||
|
||||
示例请求体:
|
||||
{
|
||||
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
"other_name": "张三1",
|
||||
"aliases": ["小张", "张工"],
|
||||
"meta_data": {"position": "工程师", "department": "技术部"}
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = info_update.end_user_id
|
||||
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
# 获取更新数据(排除 end_user_id)
|
||||
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="更新成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
elif error_msg == "无效的用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -410,11 +473,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,25 +96,50 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
|
||||
# 验证并获取语言参数
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
|
||||
f"language={language}"
|
||||
)
|
||||
|
||||
result = workspace_service.create_workspace(
|
||||
db=db, workspace=workspace, user=current_user)
|
||||
db=db, workspace=workspace, user=current_user, language=language
|
||||
)
|
||||
|
||||
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
api_logger.info(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -114,6 +147,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -126,14 +161,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -144,8 +186,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -155,6 +203,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -165,7 +214,7 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@@ -174,6 +223,7 @@ def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
@@ -185,7 +235,7 @@ def delete_workspace_member(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -195,6 +245,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -207,7 +259,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -219,6 +276,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -233,18 +292,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -254,6 +325,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -266,7 +339,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -289,6 +367,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -299,7 +378,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -307,6 +386,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -318,7 +398,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -326,6 +406,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -341,14 +423,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -357,6 +439,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -373,5 +456,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
@@ -11,35 +11,37 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -60,12 +62,13 @@ class LangChainAgent:
|
||||
self.provider = provider
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
@@ -73,9 +76,9 @@ class LangChainAgent:
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -89,6 +92,7 @@ class LangChainAgent:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -143,21 +147,22 @@ class LangChainAgent:
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
@@ -168,13 +173,13 @@ class LangChainAgent:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
@@ -185,12 +190,12 @@ class LangChainAgent:
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
@@ -198,17 +203,17 @@ class LangChainAgent:
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -248,7 +253,7 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -261,23 +266,26 @@ class LangChainAgent:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
@@ -302,7 +310,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
@@ -321,9 +329,8 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -367,14 +374,14 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
@@ -407,12 +414,13 @@ class LangChainAgent:
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -439,16 +447,16 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -482,7 +490,6 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
@@ -500,13 +507,13 @@ class LangChainAgent:
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
@@ -540,7 +547,7 @@ class LangChainAgent:
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
@@ -577,25 +584,28 @@ class LangChainAgent:
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
total_tokens = response_meta.get("token_usage", {}).get(
|
||||
"total_tokens",
|
||||
0
|
||||
) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -609,5 +619,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -16,18 +16,18 @@ class Settings:
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -57,7 +57,6 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
@@ -91,13 +90,14 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
@@ -115,6 +115,7 @@ class Settings:
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
@@ -130,7 +131,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -162,6 +163,44 @@ class Settings:
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -190,8 +229,12 @@ class Settings:
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
@@ -201,14 +244,30 @@ class Settings:
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
||||
# Memory Module Configuration (internal)
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
@@ -225,27 +284,28 @@ class Settings:
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
|
||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
||||
# Fallback to console only if file write fails
|
||||
print(f"Warning: Could not write to timing log: {e}")
|
||||
|
||||
# Always print to console (backward compatible behavior)
|
||||
print(f"✓ {step_name}: {duration:.2f}s")
|
||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||
_timing_logger = logging.getLogger(__name__)
|
||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||
|
||||
|
||||
def get_agent_logger(name: str = "agent_service",
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -12,27 +12,46 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
"""
|
||||
Problem processing node service class
|
||||
|
||||
Handles problem decomposition and extension operations using LLM services.
|
||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -53,18 +72,19 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,17 +126,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including spit_context field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -171,17 +203,18 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -215,12 +248,12 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
|
||||
@@ -6,34 +6,41 @@ import os
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db, get_db_context
|
||||
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
COUNTState,
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -50,10 +57,25 @@ async def rag_config(state):
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
@@ -61,22 +83,34 @@ async def rag_knowledge(state,question):
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception :
|
||||
retrieval_knowledge=[]
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -85,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -104,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -131,17 +169,23 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||
text_parts = []
|
||||
seen_community_names = set()
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
# community 节点用 name 去重
|
||||
if 'member_count' in item or 'core_entities' in item:
|
||||
community_name = item.get('name') or item.get('id', '')
|
||||
if community_name in seen_community_names:
|
||||
continue
|
||||
seen_community_names.add(community_name)
|
||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
@@ -150,24 +194,33 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
for key,values in problem_extension.items():
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -213,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -244,7 +297,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j!=['']:
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
@@ -257,15 +310,26 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve':dup_databases}
|
||||
|
||||
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
@@ -283,6 +347,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
@@ -296,28 +361,33 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||
}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define semaphore at module level to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -331,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||
_expanded_stmts_to_write = []
|
||||
try:
|
||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||
reranked = results_dict.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
if not community_hits:
|
||||
community_hits = results_dict.get('communities', [])
|
||||
if community_hits:
|
||||
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_hits,
|
||||
end_user_id=end_user_id,
|
||||
existing_content=clean_content,
|
||||
)
|
||||
if new_texts:
|
||||
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
@@ -366,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
@@ -413,5 +507,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
@@ -17,34 +15,143 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -61,40 +168,41 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning(f"LLM返回None,使用默认回答")
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning(f"结构化响应缺少data字段")
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
return aimessages
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -103,24 +211,38 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
system_prompt=system_prompt,
|
||||
fallback_message="信息不足,无法回答"
|
||||
)
|
||||
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
|
||||
|
||||
return cleaned_response
|
||||
else:
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -132,10 +254,26 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
storage_type=state.get("storage_type",'')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
input_summary = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
@@ -152,14 +290,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
retrieve={
|
||||
retrieve = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"title":"快速检索",
|
||||
"title": "快速检索",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"storage_type": storage_type,
|
||||
@@ -167,31 +305,56 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
return input_summary,retrieve
|
||||
return input_summary, retrieve
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start=time.time()
|
||||
storage_type=state.get("storage_type",'')
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
history = await summary_history(state)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
||||
}
|
||||
|
||||
try:
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
)
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||
retrieve_info, question, raw_results = "", data, []
|
||||
|
||||
|
||||
try:
|
||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||
# 'input_summary',RetrieveSummaryResponse)
|
||||
@@ -199,8 +362,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||
summary = summary_result[0]
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
||||
summary= {
|
||||
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||
summary = {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
@@ -213,30 +376,44 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve=state.get("retrieve", '')
|
||||
history = await summary_history( state)
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
with open("检索.json","w",encoding='utf-8') as f:
|
||||
with open("检索.json", "w", encoding='utf-8') as f:
|
||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||
retrieve=retrieve.get("Expansion_issue", [])
|
||||
start=time.time()
|
||||
retrieve_info_str=[]
|
||||
retrieve = retrieve.get("Expansion_issue", [])
|
||||
start = time.time()
|
||||
retrieve_info_str = []
|
||||
for data in retrieve:
|
||||
if data=='':
|
||||
retrieve_info_str=''
|
||||
if data == '':
|
||||
retrieve_info_str = ''
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if key=='Answer_Small':
|
||||
if key == 'Answer_Small':
|
||||
for i in value:
|
||||
retrieve_info_str.append(i)
|
||||
retrieve_info_str=list(set(retrieve_info_str))
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -248,33 +425,46 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState)-> ReadState:
|
||||
start=time.time()
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify=state.get("verify", '')
|
||||
verify_expansion_issue=verify.get("verified_data", '')
|
||||
retrieve_info_str=''
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key=='answer_small':
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str+=i+'\n'
|
||||
history=await summary_history(state)
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
@@ -286,14 +476,28 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -309,12 +513,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
return {"summary":result}
|
||||
return {"summary": result}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -10,29 +11,53 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -40,7 +65,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": verified_data,
|
||||
@@ -58,34 +83,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
|
||||
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
logger.info(
|
||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
@@ -94,29 +122,30 @@ async def Verify(state: ReadState):
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
@@ -127,11 +156,11 @@ async def Verify(state: ReadState):
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
@@ -152,4 +181,4 @@ async def Verify(state: ReadState):
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=lang,
|
||||
)
|
||||
if deleted:
|
||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
|
||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
@@ -32,10 +31,21 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
@@ -49,8 +59,8 @@ async def make_read_graph():
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
@@ -62,116 +72,15 @@ async def make_read_graph():
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
end=time.time()
|
||||
print(100*'y')
|
||||
print(f"总耗时: {end-start}s")
|
||||
print(100*'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,77 +2,104 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
actual_end_user_id, # end_user_id: User ID
|
||||
structured_messages, # message: JSON string format message list
|
||||
str(actual_config_id), # config_id: Configuration ID string
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -134,51 +178,73 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
formatted_messages = redis_messages
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
"""Time-based memory processing"""
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
history = None
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
"""
|
||||
Input schema for time retrieval tool
|
||||
|
||||
Defines the expected input parameters for time-based retrieval operations.
|
||||
Used for validation and documentation of tool parameters.
|
||||
|
||||
Attributes:
|
||||
context: User input query content for search
|
||||
end_user_id: Group ID for filtering search results, defaults to test user
|
||||
"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,51 +216,61 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
data: Data to be cleaned (can be dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
# Clean each element in list
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -216,57 +280,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -275,38 +345,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -315,7 +399,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
Format and parse message lists into different output types
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
messages: List of message objects from storage containing message data
|
||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
list: Formatted message list in the specified format
|
||||
- 'string': List of formatted text messages with role prefixes
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -42,10 +41,26 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
_EXPAND_FIELDS_TO_REMOVE = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
|
||||
|
||||
def _clean_expand_fields(obj):
|
||||
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||
if isinstance(obj, list):
|
||||
return [_clean_expand_fields(i) for i in obj]
|
||||
return obj
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
|
||||
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||
- 过滤不可序列化字段
|
||||
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||
"""
|
||||
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||
if not community_ids or not end_user_id:
|
||||
return [], []
|
||||
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
result = await search_graph_community_expand(
|
||||
connector=connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||
return [], []
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
expanded_stmts = result.get("expanded_statements", [])
|
||||
if not expanded_stmts:
|
||||
return [], []
|
||||
|
||||
existing_lines = set(existing_content.splitlines())
|
||||
new_texts = [
|
||||
s["statement"] for s in expanded_stmts
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
@@ -21,7 +87,7 @@ class SearchService:
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
@@ -30,9 +96,11 @@ class SearchService:
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
node_type: Hint for node type ("community", "summary", etc.)
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
@@ -46,8 +114,21 @@ class SearchService:
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
prefix = f"[主题:{name}] " if name else ""
|
||||
content_parts.append(f"{prefix}{content}")
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
@@ -99,7 +180,8 @@ class SearchService:
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -114,13 +196,15 @@ class SearchService:
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
expand_communities: If True, expand community hits to member statements (default: True).
|
||||
Set to False for quick-summary paths that only need community-level text.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
@@ -146,8 +230,8 @@ class SearchService:
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
@@ -157,19 +241,33 @@ class SearchService:
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||
if expand_communities and "communities" in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get('communities', [])
|
||||
if search_type == "hybrid"
|
||||
else answer.get('communities', [])
|
||||
)
|
||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_results,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
# community 节点有 member_count 或 core_entities 字段
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
@@ -57,6 +58,63 @@ async def get_chunked_dialogs(
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 加载剪枝配置
|
||||
pruning_config = None
|
||||
if config_id:
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="semantic_pruning"
|
||||
)
|
||||
|
||||
if memory_config:
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_class_infos=memory_config.ontology_class_infos,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
# 获取LLM客户端用于剪枝
|
||||
if pruning_config.pruning_switch:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
||||
|
||||
if pruned_dialogs:
|
||||
dialog_data = pruned_dialogs[0]
|
||||
remaining_msg_count = len(dialog_data.context.msgs)
|
||||
deleted_count = original_msg_count - remaining_msg_count
|
||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||
else:
|
||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||
else:
|
||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class LLMClientPool:
|
||||
"""LLM客户端连接池"""
|
||||
|
||||
def __init__(self, max_size: int = 5):
|
||||
self.max_size = max_size
|
||||
self.pools: Dict[str, asyncio.Queue] = {}
|
||||
self.active_clients: Dict[str, int] = {}
|
||||
|
||||
async def get_client(self, llm_model_id: str):
|
||||
"""获取LLM客户端"""
|
||||
if llm_model_id not in self.pools:
|
||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||
self.active_clients[llm_model_id] = 0
|
||||
|
||||
pool = self.pools[llm_model_id]
|
||||
|
||||
try:
|
||||
# 尝试从池中获取客户端
|
||||
client = pool.get_nowait()
|
||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
except asyncio.QueueEmpty:
|
||||
# 池为空,创建新客户端
|
||||
if self.active_clients[llm_model_id] < self.max_size:
|
||||
db_session = next(get_db())
|
||||
client = get_llm_client_fast(llm_model_id, db_session)
|
||||
self.active_clients[llm_model_id] += 1
|
||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
else:
|
||||
# 等待可用客户端
|
||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||
return await pool.get()
|
||||
|
||||
async def return_client(self, llm_model_id: str, client):
|
||||
"""归还LLM客户端到池中"""
|
||||
if llm_model_id in self.pools:
|
||||
try:
|
||||
self.pools[llm_model_id].put_nowait(client)
|
||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||
except asyncio.QueueFull:
|
||||
# 池已满,丢弃客户端
|
||||
self.active_clients[llm_model_id] -= 1
|
||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||
|
||||
# 全局客户端池
|
||||
llm_client_pool = LLMClientPool()
|
||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -39,6 +39,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
||||
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
@@ -71,6 +95,34 @@
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "用户是谁?",
|
||||
"extended_question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "我推荐的书是什么?",
|
||||
"extended_question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
|
||||
@@ -27,6 +27,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
@@ -151,6 +175,34 @@
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
|
||||
@@ -6,14 +6,17 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
@@ -23,18 +26,17 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +45,11 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -99,14 +103,14 @@ async def write(
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
@@ -135,9 +139,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -145,11 +151,6 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
@@ -162,13 +163,43 @@ async def write(
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
if all_entity_nodes:
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
|
||||
# 异步提交 Celery 任务
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
"new_entity_ids": new_entity_ids,
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
# 设置任务优先级(低优先级,不影响主业务)
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 聚类任务提交失败不影响主流程
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
@@ -202,9 +233,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -225,5 +255,40 @@ async def write(
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
class InterestTags(BaseModel):
|
||||
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
||||
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
return structured_response.meaningful_tags
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM筛选过程中发生错误: {e}")
|
||||
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
||||
"""
|
||||
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
||||
|
||||
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
end_user_id: 用户ID,用于获取LLM配置
|
||||
|
||||
Returns:
|
||||
筛选后的兴趣活动标签列表
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}."
|
||||
)
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
tag_list_str = ", ".join(tags)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
||||
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": rendered_prompt
|
||||
}
|
||||
]
|
||||
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestTags
|
||||
)
|
||||
|
||||
return structured_response.interest_tags
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
return tags
|
||||
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
@@ -139,14 +210,14 @@ async def get_raw_tags_from_db(
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
@@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
@@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
# 4. 限制返回的标签数量
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取用户的兴趣分布标签。
|
||||
|
||||
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
||||
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 查询更多原始标签,给LLM提供充足上下文
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
||||
|
||||
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
||||
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
||||
|
||||
# 构建最终标签列表:
|
||||
# - 原始标签中存在的,保留原始频率
|
||||
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
||||
final_tags = []
|
||||
seen = set()
|
||||
for tag in interest_tag_names:
|
||||
if tag in seen:
|
||||
continue
|
||||
seen.add(tag)
|
||||
freq = raw_freq_map.get(tag, 1)
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
# 按频率降序排列
|
||||
final_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||
自动支持火山引擎的多模态 Embedding。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
- 批量文本嵌入
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
- 火山引擎多模态 Embedding(自动识别)
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 RedBearEmbeddings 模型
|
||||
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
self.model = RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
)
|
||||
self.is_multimodal = self.model.is_multimodal_supported()
|
||||
|
||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
||||
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||
|
||||
async def response(
|
||||
self,
|
||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
return []
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
if self.is_multimodal:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = await self.model.aembed_multimodal(
|
||||
[{"type": "text", "text": text} for text in texts]
|
||||
)
|
||||
else:
|
||||
# 普通 Embedding
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
|
||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
@@ -6,11 +6,12 @@ of the memory system including LLM, chunking, pruning, and search.
|
||||
Classes:
|
||||
LLMConfig: Configuration for LLM client
|
||||
ChunkerConfig: Configuration for dialogue chunking
|
||||
OntologyClassInfo: Single ontology class with name and description
|
||||
PruningConfig: Configuration for semantic pruning
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -50,22 +51,42 @@ class ChunkerConfig(BaseModel):
|
||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||
|
||||
|
||||
class OntologyClassInfo(BaseModel):
|
||||
"""本体类型的名称与语义描述,用于剪枝提示词注入。
|
||||
|
||||
Attributes:
|
||||
class_name: 本体类型名称(如"患者"、"课程")
|
||||
class_description: 本体类型语义描述,告知 LLM 该类型在当前场景下的含义
|
||||
"""
|
||||
class_name: str = Field(..., description="本体类型名称")
|
||||
class_description: str = Field(default="", description="本体类型语义描述")
|
||||
|
||||
|
||||
class PruningConfig(BaseModel):
|
||||
"""Configuration for semantic pruning of dialogue content.
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
|
||||
pruning_scene: Scene name for pruning from ontology_scene table
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
scene_id: Optional ontology scene UUID
|
||||
ontology_class_infos: Full ontology class info (name + description) from
|
||||
ontology_class table, injected into the pruning prompt to drive
|
||||
scene-aware preservation decisions
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
|
||||
description="Scene name from ontology_scene table.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||
ontology_class_infos: List[OntologyClassInfo] = Field(
|
||||
default_factory=list,
|
||||
description="Full ontology class info (name + description) injected into pruning prompt."
|
||||
)
|
||||
|
||||
|
||||
class TemporalSearchParams(BaseModel):
|
||||
|
||||
@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
|
||||
if isinstance(v, str):
|
||||
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
|
||||
# 支持1-4位年份
|
||||
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
|
||||
match = re.match(pattern, v)
|
||||
|
||||
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
|
||||
minute = int(match.group(5)) if match.group(5) else 0
|
||||
second = int(match.group(6)) if match.group(6) else 0
|
||||
microsecond = 0
|
||||
|
||||
|
||||
# 处理微秒
|
||||
if match.group(7):
|
||||
# 补齐或截断到6位
|
||||
us_str = match.group(7).ljust(6, '0')[:6]
|
||||
microsecond = int(us_str)
|
||||
|
||||
|
||||
# 处理时区
|
||||
tzinfo = None
|
||||
if 'Z' in v or match.group(8):
|
||||
tzinfo = timezone.utc
|
||||
|
||||
|
||||
# 创建 datetime 对象
|
||||
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
|
||||
|
||||
|
||||
except (ValueError, OverflowError):
|
||||
# 日期值无效(如月份13、日期32等)
|
||||
return None
|
||||
|
||||
|
||||
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
|
||||
try:
|
||||
return datetime.fromisoformat(v.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class PerceptualEdge(Edge):
|
||||
"""Edge connecting perceptual nodes to their source chunks
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
@@ -241,17 +248,17 @@ class StatementNode(Node):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
@@ -264,25 +271,26 @@ class StatementNode(Node):
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -309,13 +317,13 @@ class StatementNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
@field_validator('emotion_type', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_type(cls, v):
|
||||
@@ -326,7 +334,7 @@ class StatementNode(Node):
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_subject(cls, v):
|
||||
@@ -337,7 +345,7 @@ class StatementNode(Node):
|
||||
if v not in valid_subjects:
|
||||
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_keywords(cls, v):
|
||||
@@ -405,19 +413,20 @@ class ExtractedEntityNode(Node):
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -444,16 +453,16 @@ class ExtractedEntityNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
"""Validate and clean aliases field using utility function.
|
||||
|
||||
This validator ensures that the aliases field is always a valid list of strings.
|
||||
@@ -507,8 +516,9 @@ class MemorySummaryNode(Node):
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
@@ -522,7 +532,7 @@ class MemorySummaryNode(Node):
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class PerceptualNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
"""
|
||||
perceptual_type: int
|
||||
file_path: str
|
||||
file_name: str
|
||||
file_ext: str
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
file_type: str
|
||||
summary_embedding: list[float] | None
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -238,7 +238,7 @@ def rerank_with_activation(
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
@@ -281,21 +281,23 @@ def rerank_with_activation(
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
@@ -724,6 +726,8 @@ async def run_hybrid_search(
|
||||
try:
|
||||
keyword_task = None
|
||||
embedding_task = None
|
||||
keyword_results: Dict[str, List] = {}
|
||||
embedding_results: Dict[str, List] = {}
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
@@ -746,35 +750,42 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
except Exception as emb_init_err:
|
||||
logger.warning(
|
||||
f"[PERF] Embedding search skipped due to init error "
|
||||
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
|
||||
)
|
||||
embedding_task = None
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||
|
||||
__all__ = ["LabelPropagationEngine"]
|
||||
@@ -0,0 +1,683 @@
|
||||
"""标签传播聚类引擎
|
||||
|
||||
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
|
||||
|
||||
支持两种模式:
|
||||
- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代
|
||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from math import sqrt
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全量迭代最大轮数,防止不收敛
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
# 社区核心实体取 top-N 数量
|
||||
CORE_ENTITY_LIMIT = 10
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||
if not v1 or not v2 or len(v1) != len(v2):
|
||||
return 0.0
|
||||
dot = sum(a * b for a, b in zip(v1, v2))
|
||||
norm1 = sqrt(sum(a * a for a in v1))
|
||||
norm2 = sqrt(sum(b * b for b in v2))
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
return dot / (norm1 * norm2)
|
||||
|
||||
|
||||
def _weighted_vote(
|
||||
neighbors: List[Dict],
|
||||
self_embedding: Optional[List[float]],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
加权多数投票,选出得票最高的社区。
|
||||
|
||||
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||
没有 community_id 的邻居不参与投票。
|
||||
"""
|
||||
votes: Dict[str, float] = {}
|
||||
for nb in neighbors:
|
||||
cid = nb.get("community_id")
|
||||
if not cid:
|
||||
continue
|
||||
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||
act = nb.get("activation_value") or 0.5
|
||||
# 语义相似度权重 0.6,激活值权重 0.4
|
||||
weight = 0.6 * sem + 0.4 * act
|
||||
votes[cid] = votes.get(cid, 0.0) + weight
|
||||
|
||||
if not votes:
|
||||
return None
|
||||
return max(votes, key=votes.__getitem__)
|
||||
|
||||
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
# 缓存客户端实例,避免重复初始化
|
||||
self._llm_client = None
|
||||
self._embedder_client = None
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_entity_ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
统一入口:自动判断全量还是增量。
|
||||
|
||||
- 若该用户尚无 Community 节点 → 全量初始化
|
||||
- 否则 → 增量更新(仅处理 new_entity_ids)
|
||||
"""
|
||||
has_communities = await self.repo.has_communities(end_user_id)
|
||||
if not has_communities:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
|
||||
await self.full_clustering(end_user_id)
|
||||
else:
|
||||
if new_entity_ids:
|
||||
logger.info(
|
||||
f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
|
||||
)
|
||||
await self.incremental_update(new_entity_ids, end_user_id)
|
||||
|
||||
async def full_clustering(self, end_user_id: str) -> None:
|
||||
"""
|
||||
全量标签传播初始化(分批处理,控制内存峰值)。
|
||||
|
||||
策略:
|
||||
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
|
||||
- labels 字典跨批次共享(只存 id→community_id,内存极小)
|
||||
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||
- 所有批次完成后统一 flush 和 merge
|
||||
"""
|
||||
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||
|
||||
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||
total_count = await self.repo.get_entity_count(end_user_id)
|
||||
if not total_count:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||
return
|
||||
|
||||
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
|
||||
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批")
|
||||
|
||||
# labels 跨批次共享:只存 id→community_id,内存极小
|
||||
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
|
||||
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
|
||||
|
||||
for batch_start in range(0, total_count, BATCH_SIZE):
|
||||
batch_entities = await self.repo.get_entities_page(
|
||||
end_user_id, skip=batch_start, limit=BATCH_SIZE
|
||||
)
|
||||
if not batch_entities:
|
||||
break
|
||||
|
||||
batch_ids = [e["id"] for e in batch_entities]
|
||||
batch_embeddings: Dict[str, Optional[List[float]]] = {
|
||||
e["id"]: e.get("name_embedding") for e in batch_entities
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:"
|
||||
f"加载 {len(batch_entities)} 个实体的邻居图..."
|
||||
)
|
||||
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
|
||||
batch_ids, end_user_id
|
||||
)
|
||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
changed = 0
|
||||
for entity in batch_entities:
|
||||
eid = entity["id"]
|
||||
neighbors = neighbors_cache.get(eid, [])
|
||||
|
||||
# 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值)
|
||||
enriched = []
|
||||
for nb in neighbors:
|
||||
nb_copy = dict(nb)
|
||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||
enriched.append(nb_copy)
|
||||
|
||||
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
|
||||
if new_label and new_label != labels[eid]:
|
||||
labels[eid] = new_label
|
||||
changed += 1
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
|
||||
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
|
||||
)
|
||||
if changed == 0:
|
||||
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
|
||||
break
|
||||
|
||||
# 释放本批次的大对象
|
||||
del neighbors_cache, batch_embeddings, batch_entities
|
||||
|
||||
# 所有批次完成,统一写入 Neo4j
|
||||
await self._flush_labels(labels, end_user_id)
|
||||
pre_merge_count = len(set(labels.values()))
|
||||
logger.info(
|
||||
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体,开始后处理合并"
|
||||
)
|
||||
|
||||
all_community_ids = list(set(labels.values()))
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
|
||||
# 查询存活社区并生成元数据
|
||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||
surviving_community_ids = list({
|
||||
e.get("community_id") for e in surviving_communities
|
||||
if e.get("community_id")
|
||||
})
|
||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
增量更新:只处理新实体及其邻居,不重跑全图。
|
||||
|
||||
1. 对每个新实体查询邻居
|
||||
2. 加权多数投票决定社区归属
|
||||
3. 若邻居无社区 → 创建新社区
|
||||
4. 若邻居分属多个社区 → 评估是否合并
|
||||
"""
|
||||
# 收集所有需要生成元数据的社区ID
|
||||
communities_to_update = set()
|
||||
|
||||
for entity_id in new_entity_ids:
|
||||
cid = await self._process_single_entity(entity_id, end_user_id)
|
||||
if cid:
|
||||
communities_to_update.add(cid)
|
||||
|
||||
# 批量生成所有社区的元数据
|
||||
if communities_to_update:
|
||||
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 内部方法
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _process_single_entity(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
处理单个新实体的社区分配。
|
||||
|
||||
该函数会为新实体分配社区,可能的情况包括:
|
||||
1. 孤立实体(无邻居):创建新的单成员社区
|
||||
2. 邻居都没有社区:创建新社区并将实体和邻居都加入
|
||||
3. 邻居有社区:通过加权投票选择最合适的社区加入
|
||||
|
||||
Returns:
|
||||
Optional[str]: 分配到的社区ID。当前实现总是返回一个有效的社区ID,
|
||||
但返回类型保留为Optional以支持未来可能的扩展场景
|
||||
(例如:实体无法分配到任何社区的情况)。
|
||||
调用方应检查返回值的真假性(truthiness)。
|
||||
"""
|
||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||
|
||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||
self_embedding = await self._get_entity_embedding(entity_id, end_user_id)
|
||||
|
||||
if not neighbors:
|
||||
# 孤立实体:创建单成员社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
return new_cid
|
||||
|
||||
# 统计邻居社区分布
|
||||
community_ids_in_neighbors = set(
|
||||
nb["community_id"] for nb in neighbors if nb.get("community_id")
|
||||
)
|
||||
|
||||
target_cid = _weighted_vote(neighbors, self_embedding)
|
||||
|
||||
if target_cid is None:
|
||||
# 邻居都没有社区,连同新实体一起创建新社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
for nb in neighbors:
|
||||
await self.repo.assign_entity_to_community(
|
||||
nb["id"], new_cid, end_user_id
|
||||
)
|
||||
await self.repo.refresh_member_count(new_cid, end_user_id)
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
return new_cid
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
await self.repo.refresh_member_count(target_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
|
||||
|
||||
# 若邻居分属多个社区,评估合并
|
||||
if len(community_ids_in_neighbors) > 1:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
# 返回目标社区ID,稍后批量生成元数据
|
||||
return target_cid
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
评估多个社区是否应合并。
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.85
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
cids = list(community_ids)
|
||||
for i in range(len(cids)):
|
||||
for j in range(i + 1, len(cids)):
|
||||
sim = _cosine_similarity(
|
||||
community_embeddings[cids[i]],
|
||||
community_embeddings[cids[j]],
|
||||
)
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
|
||||
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
|
||||
# (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并)
|
||||
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
|
||||
|
||||
def get_root(x: str) -> str:
|
||||
"""路径压缩,找到 x 当前所属的根社区。"""
|
||||
while x in merged_into:
|
||||
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
|
||||
x = merged_into[x]
|
||||
return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
root1, root2 = get_root(c1), get_root(c2)
|
||||
if root1 == root2:
|
||||
continue
|
||||
|
||||
# 用合并后的最新平均向量重新验证相似度
|
||||
# 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并
|
||||
current_sim = _cosine_similarity(
|
||||
community_embeddings.get(root1),
|
||||
community_embeddings.get(root2),
|
||||
)
|
||||
if current_sim <= MERGE_THRESHOLD:
|
||||
# 合并后向量已漂移,不再满足阈值,跳过
|
||||
logger.debug(
|
||||
f"[Clustering] 跳过合并 {root1} ↔ {root2},"
|
||||
f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
merged_into[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
|
||||
# 合并后重新计算 keep 的平均向量(加权平均)
|
||||
keep_emb = community_embeddings.get(keep)
|
||||
dissolve_emb = community_embeddings.get(dissolve)
|
||||
keep_size = community_sizes.get(keep, 0)
|
||||
dissolve_size = community_sizes.get(dissolve, 0)
|
||||
total_size = keep_size + dissolve_size
|
||||
if keep_emb and dissolve_emb and total_size > 0:
|
||||
dim = len(keep_emb)
|
||||
community_embeddings[keep] = [
|
||||
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
|
||||
for i in range(dim)
|
||||
]
|
||||
community_embeddings[dissolve] = None
|
||||
|
||||
community_sizes[keep] = total_size
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
|
||||
)
|
||||
|
||||
async def _flush_labels(
|
||||
self, labels: Dict[str, str], end_user_id: str
|
||||
) -> None:
|
||||
"""将内存中的标签批量写入 Neo4j。"""
|
||||
# 先创建所有唯一社区节点
|
||||
unique_communities = set(labels.values())
|
||||
for cid in unique_communities:
|
||||
await self.repo.upsert_community(cid, end_user_id)
|
||||
|
||||
# 再批量分配实体
|
||||
for entity_id, community_id in labels.items():
|
||||
await self.repo.assign_entity_to_community(
|
||||
entity_id, community_id, end_user_id
|
||||
)
|
||||
|
||||
# 刷新成员数
|
||||
for cid in unique_communities:
|
||||
await self.repo.refresh_member_count(cid, end_user_id)
|
||||
|
||||
async def _get_entity_embedding(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[List[float]]:
|
||||
"""查询单个实体的 name_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
|
||||
"RETURN e.name_embedding AS name_embedding",
|
||||
eid=entity_id,
|
||||
uid=end_user_id,
|
||||
)
|
||||
return result[0]["name_embedding"] if result else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
|
||||
lines = []
|
||||
for m in members:
|
||||
m_name = m.get("name", "")
|
||||
aliases = m.get("aliases") or []
|
||||
description = m.get("description") or ""
|
||||
example = m.get("example") or ""
|
||||
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||
desc_str = f":{description}" if description else ""
|
||||
example_str = f"(示例:{example})" if example else ""
|
||||
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。
|
||||
|
||||
流程:
|
||||
1. 批量准备所有社区的 prompt
|
||||
2. 并发调用 LLM 生成所有社区的 name / summary
|
||||
3. 批量 embed 所有 summary
|
||||
4. 批量写入数据库
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
async def _prepare_one(cid: str) -> Optional[Dict]:
|
||||
"""准备单个社区的数据和 prompt"""
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||
return None
|
||||
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
# 默认值
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
# 准备 LLM prompt(如果配置了 LLM)
|
||||
prompt = None
|
||||
if self.llm_model_id:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"prompt": prompt,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# --- 阶段1:并发准备所有社区数据 ---
|
||||
results = await asyncio.gather(
|
||||
*[_prepare_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
metadata_list = []
|
||||
for cid, res in zip(community_ids, results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
|
||||
elif res is not None:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量调用 LLM 生成 name 和 summary ---
|
||||
if self.llm_model_id:
|
||||
llm_client = self._get_llm_client()
|
||||
if not llm_client:
|
||||
logger.warning(
|
||||
f"[Clustering] LLM 已配置(model_id={self.llm_model_id})但客户端初始化失败,"
|
||||
f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||
)
|
||||
if llm_client:
|
||||
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
|
||||
|
||||
if prompts_to_process:
|
||||
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
|
||||
|
||||
async def _call_llm(idx: int, meta: Dict) -> tuple:
|
||||
"""单个 LLM 调用"""
|
||||
try:
|
||||
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
return (idx, text, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
|
||||
return (idx, None, e)
|
||||
|
||||
# 并发调用所有 LLM 请求
|
||||
llm_results = await asyncio.gather(
|
||||
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# 解析 LLM 响应
|
||||
for result in llm_results:
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
idx, text, error = result
|
||||
if error or not text:
|
||||
continue
|
||||
|
||||
meta = metadata_list[idx]
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
meta["name"] = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
meta["summary"] = line[3:].strip()
|
||||
|
||||
logger.info(f"[Clustering] LLM 批量生成完成")
|
||||
|
||||
# --- 阶段3:批量生成 summary_embedding ---
|
||||
if self.embedding_model_id:
|
||||
embedder = self._get_embedder_client()
|
||||
if not embedder:
|
||||
logger.warning(
|
||||
f"[Clustering] Embedding 已配置(model_id={self.embedding_model_id})但客户端初始化失败,"
|
||||
f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||
)
|
||||
if embedder:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
logger.info(f"[Clustering] Embedding 批量生成完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段4:批量写入数据库 ---
|
||||
# 移除 prompt 字段(不需要存储)
|
||||
for m in metadata_list:
|
||||
m.pop("prompt", None)
|
||||
|
||||
if len(metadata_list) == 1:
|
||||
m = metadata_list[0]
|
||||
result = await self.repo.update_community_metadata(
|
||||
community_id=m["community_id"],
|
||||
end_user_id=m["end_user_id"],
|
||||
name=m["name"],
|
||||
summary=m["summary"],
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if not result:
|
||||
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
else:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
|
||||
def _get_llm_client(self):
|
||||
"""获取或创建 LLM 客户端(单例模式)"""
|
||||
if self._llm_client is None and self.llm_model_id:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
with get_db_context() as db:
|
||||
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
|
||||
return self._llm_client
|
||||
|
||||
def _get_embedder_client(self):
|
||||
"""获取或创建 Embedder 客户端(单例模式)"""
|
||||
if self._embedder_client is None and self.embedding_model_id:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
with get_db_context() as db:
|
||||
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
|
||||
return self._embedder_client
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
场景特定配置 - 统一填充词库
|
||||
|
||||
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
|
||||
本模块仅保留统一填充词库(filler_phrases),用于识别无意义寒暄/表情/口头禅。
|
||||
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
|
||||
"""
|
||||
|
||||
from typing import List, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式(仅保留填充词库)"""
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 所有场景共用统一填充词库"""
|
||||
|
||||
BASE_FILLERS: Set[str] = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
"拜拜", "再见", "88", "拜", "回见",
|
||||
# 口头禅
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 服务类套话
|
||||
"请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "欢迎下次咨询",
|
||||
# 外呼套话
|
||||
"喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
|
||||
# 教育场景套话
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
|
||||
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
|
||||
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
|
||||
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
|
||||
# 网络用语
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str = "") -> ScenePatterns:
|
||||
"""所有场景统一返回同一份填充词库"""
|
||||
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
|
||||
@@ -203,6 +203,7 @@ def accurate_match(
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
同时检测某实体的 name 是否命中另一实体的 aliases,若命中则直接合并。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
@@ -240,6 +241,48 @@ def accurate_match(
|
||||
pass
|
||||
|
||||
deduped_entities = list(canonical_map.values())
|
||||
|
||||
# 2) 第二轮:检测某实体的 name 是否命中另一实体的 aliases(alias-to-name 精确合并)
|
||||
# 场景:LLM 把 aliases 中的词(如"齐齐")又单独抽取为独立实体,需在此阶段合并掉
|
||||
# 优化:先构建 (end_user_id, alias_lower) -> canonical 的反向索引,查找 O(1)
|
||||
alias_index: Dict[tuple, ExtractedEntityNode] = {}
|
||||
for canonical in deduped_entities:
|
||||
uid = getattr(canonical, "end_user_id", None)
|
||||
for alias in (getattr(canonical, "aliases", []) or []):
|
||||
alias_lower = alias.strip().lower()
|
||||
if alias_lower:
|
||||
alias_index[(uid, alias_lower)] = canonical
|
||||
|
||||
i = 0
|
||||
while i < len(deduped_entities):
|
||||
ent = deduped_entities[i]
|
||||
ent_name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
ent_uid = getattr(ent, "end_user_id", None)
|
||||
canonical = alias_index.get((ent_uid, ent_name))
|
||||
# 确保不是自身
|
||||
if canonical is not None and canonical.id != ent.id:
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[ent.id] = canonical.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == ent.id:
|
||||
id_redirect[k] = canonical.id
|
||||
try:
|
||||
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": canonical.id,
|
||||
"end_user_id": canonical.end_user_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
}
|
||||
exact_merge_map[k]["merged_ids"].add(ent.id)
|
||||
except Exception:
|
||||
pass
|
||||
deduped_entities.pop(i)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return deduped_entities, id_redirect, exact_merge_map
|
||||
|
||||
def fuzzy_match(
|
||||
|
||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class ChunkerStrategy(Enum):
|
||||
"""Supported chunking strategies."""
|
||||
RECURSIVE = "RecursiveChunker"
|
||||
SEMANTIC = "SemanticChunker"
|
||||
LATE = "LateChunker"
|
||||
NEURAL = "NeuralChunker"
|
||||
LLM = "LLMChunker"
|
||||
|
||||
@classmethod
|
||||
def get_valid_strategies(cls) -> List[str]:
|
||||
"""Get list of valid strategy names."""
|
||||
return [strategy.value for strategy in cls]
|
||||
|
||||
|
||||
class DialogueChunker:
|
||||
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
|
||||
|
||||
@@ -17,23 +33,51 @@ class DialogueChunker:
|
||||
of different chunking strategies to dialogue data.
|
||||
"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
|
||||
"""Initialize the DialogueChunker with a specific chunking strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
|
||||
llm_client: LLM client instance (required for LLMChunker strategy)
|
||||
|
||||
Raises:
|
||||
ValueError: If chunker_strategy is invalid or required parameters are missing
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# Validate strategy
|
||||
valid_strategies = ChunkerStrategy.get_valid_strategies()
|
||||
if chunker_strategy not in valid_strategies:
|
||||
raise ValueError(
|
||||
f"Invalid chunker_strategy: '{chunker_strategy}'. "
|
||||
f"Must be one of {valid_strategies}"
|
||||
)
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
self.chunker_strategy = chunker_strategy
|
||||
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
|
||||
|
||||
try:
|
||||
# Load and validate configuration
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
if not chunker_config_dict:
|
||||
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
|
||||
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
|
||||
# Initialize chunker client
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("llm_client is required for LLMChunker strategy")
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
|
||||
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
|
||||
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
|
||||
"""Process a dialogue by generating chunks and adding them to the DialogData object.
|
||||
|
||||
Args:
|
||||
@@ -43,54 +87,125 @@ class DialogueChunker:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
ValueError: If dialogue is invalid or chunking fails
|
||||
Exception: If chunking process encounters an error
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# Validate input
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
|
||||
f"Context: {dialogue.context is not None}, "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
|
||||
f"using strategy: {self.chunker_strategy}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate chunks
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
return chunks
|
||||
# Validate results
|
||||
if not chunks or len(chunks) == 0:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs)}, "
|
||||
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
|
||||
logger.info(
|
||||
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
|
||||
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
except ValueError:
|
||||
# Re-raise validation errors
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def save_chunking_results(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
dialogue: DialogData,
|
||||
output_path: Optional[str] = None,
|
||||
preview_length: int = 100
|
||||
) -> str:
|
||||
"""Save the chunking results to a file and return the output path.
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output
|
||||
chunks: List of Chunk objects to save
|
||||
dialogue: The DialogData object that was processed
|
||||
output_path: Optional path to save the output (defaults to current directory)
|
||||
preview_length: Maximum length of content preview (default: 100)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
|
||||
Raises:
|
||||
ValueError: If chunks or dialogue is invalid
|
||||
IOError: If file writing fails
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
# Validate input
|
||||
if not chunks:
|
||||
raise ValueError("chunks list cannot be empty")
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
# Generate default output path if not provided
|
||||
if not output_path:
|
||||
output_dir = Path(__file__).parent.parent.parent
|
||||
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
|
||||
logger.info(f"Saving chunking results to: {output_path}")
|
||||
|
||||
try:
|
||||
# Prepare output content
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
|
||||
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
|
||||
f"Generated {len(chunks)} chunks:",
|
||||
""
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
content_preview = chunk.content[:preview_length] if chunk.content else ""
|
||||
if len(chunk.content) > preview_length:
|
||||
content_preview += "..."
|
||||
|
||||
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {content_preview}")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
output_lines.append("")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
# Write to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
|
||||
logger.info(f"Chunking results saved to: {output_path}")
|
||||
return output_path
|
||||
logger.info(f"Successfully saved chunking results to: {output_path}")
|
||||
return output_path
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
logger.debug("=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
logger.debug("=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
logger.debug("=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
自我反思引擎实现
|
||||
Self-Reflection Engine Implementation
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
1. 基于时间的反思 - 根据时间周期触发反思
|
||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
||||
3. 综合反思 - 整合多种反思策略
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
This module implements the self-reflection functionality of the memory system, including:
|
||||
1. Time-based reflection - Triggers reflection based on time cycles
|
||||
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||
4. Reflection result application - Updates memory database
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
# Configure logging
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(
|
||||
@@ -49,35 +49,62 @@ else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
"""Translation response model for language conversion"""
|
||||
data: str
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
"""
|
||||
Reflection range enumeration
|
||||
|
||||
Defines the scope of data to be included in reflection operations.
|
||||
"""
|
||||
PARTIAL = "partial" # Reflect from retrieval results
|
||||
ALL = "all" # Reflect from entire database
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
"""
|
||||
Reflection baseline enumeration
|
||||
|
||||
Defines the strategy or approach used for reflection operations.
|
||||
"""
|
||||
TIME = "TIME" # Time-based reflection
|
||||
FACT = "FACT" # Fact-based reflection
|
||||
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
|
||||
|
||||
|
||||
class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
"""
|
||||
Reflection engine configuration
|
||||
|
||||
Defines all configuration parameters for the reflection engine including
|
||||
operation modes, model settings, and evaluation criteria.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether reflection engine is enabled
|
||||
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||
model_id: LLM model identifier for reflection operations
|
||||
end_user_id: User identifier for scoped operations
|
||||
output_example: Example output format for guidance
|
||||
memory_verify: Enable memory verification checks
|
||||
quality_assessment: Enable quality assessment evaluation
|
||||
violation_handling_strategy: Strategy for handling violations
|
||||
language_type: Language type for output ("zh" or "en")
|
||||
"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
iteration_period: str = "3" # Reflection cycle period
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
model_id: Optional[str] = None # Model ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
output_example: Optional[str] = None # Output example
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
# Evaluation related fields
|
||||
memory_verify: bool = True # Memory verification
|
||||
quality_assessment: bool = True # Quality assessment
|
||||
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
||||
|
||||
|
||||
class ReflectionResult(BaseModel):
|
||||
"""反思结果"""
|
||||
"""
|
||||
Reflection operation result
|
||||
|
||||
Contains comprehensive information about the outcome of a reflection operation
|
||||
including success status, metrics, and execution details.
|
||||
|
||||
Attributes:
|
||||
success: Whether the reflection operation succeeded
|
||||
message: Descriptive message about the operation result
|
||||
conflicts_found: Number of conflicts detected during reflection
|
||||
conflicts_resolved: Number of conflicts successfully resolved
|
||||
memories_updated: Number of memory entries updated in database
|
||||
execution_time: Total time taken for the reflection operation
|
||||
details: Additional details about the operation (optional)
|
||||
"""
|
||||
success: bool
|
||||
message: str
|
||||
conflicts_found: int = 0
|
||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
||||
|
||||
class ReflectionEngine:
|
||||
"""
|
||||
自我反思引擎
|
||||
|
||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
||||
Self-Reflection Engine
|
||||
|
||||
Responsible for executing memory system self-reflection operations including
|
||||
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||
reflection strategies and provides comprehensive result tracking.
|
||||
|
||||
The engine can operate in different modes:
|
||||
- Time-based: Reflects on memories within specific time periods
|
||||
- Fact-based: Detects and resolves factual conflicts in memories
|
||||
- Hybrid: Combines multiple reflection strategies
|
||||
|
||||
Attributes:
|
||||
config: Reflection engine configuration
|
||||
neo4j_connector: Neo4j database connector
|
||||
llm_client: Language model client for analysis
|
||||
Various function handlers for data processing and prompt rendering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
Initialize reflection engine
|
||||
|
||||
Sets up the reflection engine with configuration and optional dependencies.
|
||||
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||
|
||||
Args:
|
||||
config: 反思引擎配置
|
||||
neo4j_connector: Neo4j 连接器(可选)
|
||||
llm_client: LLM 客户端(可选)
|
||||
get_data_func: 获取数据的函数(可选)
|
||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
||||
conflict_schema: 冲突结果 Schema(可选)
|
||||
reflexion_schema: 反思结果 Schema(可选)
|
||||
update_query: 更新查询语句(可选)
|
||||
config: Reflection engine configuration object
|
||||
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||
llm_client: LLM client instance (optional, will be created if not provided)
|
||||
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||
conflict_schema: Schema for conflict result validation (optional)
|
||||
reflexion_schema: Schema for reflection result validation (optional)
|
||||
update_query: Query string for database updates (optional)
|
||||
"""
|
||||
self.config = config
|
||||
self.neo4j_connector = neo4j_connector
|
||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
# Lazy import to avoid circular dependencies
|
||||
self._lazy_init_done = False
|
||||
|
||||
def _lazy_init(self):
|
||||
"""延迟初始化,避免循环导入"""
|
||||
"""
|
||||
Lazy initialization to avoid circular imports
|
||||
|
||||
Initializes dependencies only when needed, preventing circular import issues
|
||||
and optimizing startup performance. Sets up default implementations for
|
||||
any components not provided during construction.
|
||||
"""
|
||||
if self._lazy_init_done:
|
||||
return
|
||||
|
||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# If llm_client is a string (model_id), use it to initialize the client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||
"max_tokens": 600, # Limit maximum token count
|
||||
"top_p": 0.8, # Optimize sampling parameters
|
||||
"stream": False, # Ensure non-streaming output for fastest response
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
||||
if self.get_data_func is None:
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
# Import get_data_statement function
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
||||
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
Execute complete reflection workflow
|
||||
|
||||
Performs the full reflection process including data retrieval, conflict detection,
|
||||
conflict resolution, and memory updates. This is the main entry point for
|
||||
reflection operations.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host identifier for scoping reflection operations
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive result of the reflection operation including
|
||||
success status, conflict metrics, and execution time
|
||||
"""
|
||||
# 延迟初始化
|
||||
# Lazy initialization
|
||||
self._lazy_init()
|
||||
|
||||
if not self.config.enabled:
|
||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
# 1. Get reflection data
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
# 2. Detect conflicts (fact-based reflection)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
# 4. Apply reflection results (update memory database)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
"""
|
||||
Translate Chinese text to English
|
||||
|
||||
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||
Provides consistent translation format for reflection results.
|
||||
|
||||
Args:
|
||||
text: Chinese text to be translated
|
||||
|
||||
Returns:
|
||||
str: Translated English text
|
||||
"""
|
||||
# Translate Chinese to English
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
"""
|
||||
Extract and translate reflection data to English
|
||||
|
||||
Processes reflection data structure and translates all Chinese content to English.
|
||||
Handles nested data structures including memory verifications, quality assessments,
|
||||
and reflection data while preserving the original structure.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing reflection data with Chinese content
|
||||
|
||||
Returns:
|
||||
dict: Translated data structure with English content
|
||||
"""
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
"""
|
||||
Execute reflection workflow with comprehensive data processing
|
||||
|
||||
Performs a complete reflection operation including conflict detection, resolution,
|
||||
and result formatting. Supports both Chinese and English output based on
|
||||
configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Comprehensive reflection results including source data, memory verifications,
|
||||
quality assessments, and reflection data. Results are translated to English
|
||||
if language_type is set to 'en'.
|
||||
"""
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
# Traverse data to extract fields
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
# Remove created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
# Traverse data to extract reflexion fields
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
"""
|
||||
Extract source_data and databasets fields from example.json
|
||||
|
||||
Reads reflection example data from the example.json file and extracts
|
||||
the source data and database statements for testing and demonstration purposes.
|
||||
|
||||
Returns:
|
||||
tuple: (source_data, databasets) extracted from the example file
|
||||
Returns empty lists if file reading fails
|
||||
"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
# Read JSON file
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
# Extract fields under memory_verify
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
|
||||
根据配置的反思范围获取需要反思的记忆数据。
|
||||
Get reflection data from database
|
||||
|
||||
Retrieves memory data for reflection based on the configured reflection range.
|
||||
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping data retrieval
|
||||
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||
Returns empty lists if query fails
|
||||
"""
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
||||
Detect conflicts (fact-based reflection)
|
||||
|
||||
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||
Performs comprehensive conflict detection including memory verification and
|
||||
quality assessment based on configuration settings.
|
||||
|
||||
Args:
|
||||
data: 待检测的记忆数据
|
||||
data: Memory data to be analyzed for conflicts
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 冲突记忆列表
|
||||
List[Any]: List of detected conflicts with detailed analysis
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
# Data preprocessing: if data is too small, return no conflicts directly
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# Use converted data
|
||||
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
# Render conflict detection prompt
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema,
|
||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
# 调用 LLM 进行冲突检测
|
||||
# Call LLM for conflict detection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.conflict_schema
|
||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
||||
logging.error("LLM 冲突检测输出解析失败")
|
||||
return []
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return [response.model_dump()]
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
使用 LLM 对检测到的冲突进行反思和解决。
|
||||
Resolve detected conflicts
|
||||
|
||||
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||
|
||||
Args:
|
||||
conflicts: 冲突列表
|
||||
conflicts: List of conflicts to be resolved
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 解决方案列表
|
||||
List[Any]: List of resolution solutions with reflection results
|
||||
"""
|
||||
if not conflicts:
|
||||
return []
|
||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
# Process each conflict in parallel
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解决单个冲突"""
|
||||
"""Resolve a single conflict"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
# 渲染反思提示词
|
||||
# Render reflection prompt
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema,
|
||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
# 调用 LLM 进行反思
|
||||
# Call LLM for reflection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.reflexion_schema
|
||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return response.model_dump()
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
||||
logging.warning(f"解决单个冲突失败: {e}")
|
||||
return None
|
||||
|
||||
# 并发执行所有冲突解决任务
|
||||
# Execute all conflict resolution tasks concurrently
|
||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 过滤掉失败的结果
|
||||
# Filter out failed results
|
||||
solved = [r for r in results if r is not None]
|
||||
|
||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
|
||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
||||
Apply reflection results (update memory database)
|
||||
|
||||
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||
Processes the solved data and applies changes to the memory storage system.
|
||||
|
||||
Args:
|
||||
solved_data: 解决方案列表
|
||||
solved_data: List of resolved conflict solutions with reflection data
|
||||
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
int: Number of successfully updated memory entries
|
||||
"""
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
|
||||
# 基于时间的反思方法
|
||||
# Time-based reflection methods
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
|
||||
根据时间周期触发反思,检查在指定时间段内的记忆。
|
||||
Time-based reflection
|
||||
|
||||
Triggers reflection based on time cycles, checking memories within
|
||||
specified time periods. Uses the configured iteration period if
|
||||
no specific time period is provided.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
time_period: Time period (e.g., "three hours"), uses config value if not provided
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
period = time_period or self.config.iteration_period
|
||||
logging.info(f"执行基于时间的反思,周期: {period}")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 基于事实的反思方法
|
||||
# Fact-based reflection methods
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
|
||||
检测记忆中的事实冲突并解决。
|
||||
Fact-based reflection
|
||||
|
||||
Detects and resolves factual conflicts within memories. Analyzes
|
||||
memory data for inconsistencies and contradictions that need resolution.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
logging.info("执行基于事实的反思")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 综合反思方法
|
||||
# Comprehensive reflection methods
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
|
||||
整合基于时间和基于事实的反思策略。
|
||||
Comprehensive reflection
|
||||
|
||||
Integrates time-based and fact-based reflection strategies based on
|
||||
the configured baseline. Supports hybrid approaches that combine
|
||||
multiple reflection methodologies.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result combining
|
||||
multiple strategies if using hybrid baseline
|
||||
"""
|
||||
logging.info("执行综合反思")
|
||||
|
||||
# 根据配置的基线选择反思策略
|
||||
# Choose reflection strategy based on configured baseline
|
||||
if self.config.baseline == ReflectionBaseline.TIME:
|
||||
return await self.time_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||
return await self.fact_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
||||
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||
time_result = await self.time_based_reflection(host_id)
|
||||
fact_result = await self.fact_based_reflection(host_id)
|
||||
|
||||
# 合并结果
|
||||
# Merge results
|
||||
return ReflectionResult(
|
||||
success=time_result.success and fact_result.success,
|
||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||
|
||||
@@ -2,9 +2,17 @@ import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
Escape special characters in Lucene queries
|
||||
|
||||
Prevents ParseException when using Neo4j full-text search procedures.
|
||||
Escapes all Lucene reserved special characters and operators.
|
||||
|
||||
Args:
|
||||
query: Original query string
|
||||
|
||||
Returns:
|
||||
str: Escaped query string safe for Lucene search
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
"""
|
||||
Extract clean plain-text query from various input forms
|
||||
|
||||
Handles the following cases:
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Fallbacks to the raw string when parsing fails
|
||||
- Falls back to raw string when parsing fails
|
||||
- Handles dictionary-type input
|
||||
- Best-effort unescape common escape characters
|
||||
|
||||
Args:
|
||||
query_input: Query input in various forms (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
str: Extracted plain-text query string
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
Validate if the date string is in the format YYYY-MM-DD.
|
||||
Validate if date string conforms to YYYY-MM-DD format
|
||||
|
||||
Args:
|
||||
date_str: Date string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if format is correct, False otherwise
|
||||
"""
|
||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||
return bool(re.match(pattern, date_str))
|
||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
|
||||
"""预处理日期字符串,处理特殊格式"""
|
||||
"""
|
||||
预处理日期字符串,处理特殊格式
|
||||
|
||||
处理以下特殊格式:
|
||||
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||
- 混合分隔符,统一为 "-"
|
||||
|
||||
Args:
|
||||
date_str: 原始日期字符串
|
||||
|
||||
Returns:
|
||||
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||
"""
|
||||
|
||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
|
||||
"""备选解析方案"""
|
||||
"""
|
||||
备选日期解析方案
|
||||
|
||||
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||
支持多种常见的日期格式,包括:
|
||||
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||
- YYYYMMDD, YYMMDD
|
||||
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||
|
||||
Args:
|
||||
date_str: 待解析的日期字符串
|
||||
|
||||
Returns:
|
||||
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||
"""
|
||||
|
||||
# 尝试常见的日期格式[citation:4][citation:5]
|
||||
formats_to_try = [
|
||||
|
||||
@@ -327,7 +327,7 @@ class MultiOntologyParser:
|
||||
|
||||
Example:
|
||||
>>> parser = MultiOntologyParser([
|
||||
... "General_purpose_entity.ttl",
|
||||
... "app/core/memory/ontology_services/General_purpose_entity.ttl",
|
||||
... "domain_specific.owl"
|
||||
... ])
|
||||
>>> registry = parser.parse_all()
|
||||
|
||||
@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str,
|
||||
language: str = "zh"
|
||||
language: str = "zh",
|
||||
user_display_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
# 如果没有提供 user_display_name,使用默认值
|
||||
if user_display_name is None:
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
|
||||
template = prompt_env.get_template("user_summary.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements,
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements),
|
||||
'language': language
|
||||
'language': language,
|
||||
'user_display_name': user_display_name
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the interest filter prompt using the interest_filter.jinja2 template.
|
||||
|
||||
Args:
|
||||
tag_list: Comma-separated string of raw tags to filter
|
||||
language: Output language ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("interest_filter.jinja2")
|
||||
rendered_prompt = template.render(tag_list=tag_list, language=language)
|
||||
log_prompt_rendering('interest filter', rendered_prompt)
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, dialog_text
|
||||
输入:pruning_scene, ontology_class_infos, dialog_text, language
|
||||
- ontology_class_infos: List[{class_name: str, class_description: str}]
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -9,39 +10,103 @@
|
||||
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
||||
- addresses: [string],地址/地点相关文本
|
||||
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
||||
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
|
||||
|
||||
要求:
|
||||
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
||||
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{% set scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in scene_instructions %}
|
||||
{% set scene_key = 'education' %}
|
||||
{# ── 确定场景说明 ── #}
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %}
|
||||
{% else %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务:
|
||||
1. 判断对话是否与指定场景相关;
|
||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||
|
||||
场景说明:{{ instruction }}
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
【本场景实体类型定义】
|
||||
以下实体类型定义了本场景中哪些内容是重要的。
|
||||
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}:{{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
重要提示:只要对话中出现与上述任意实体类型相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
【必须保留的内容(不可删除)】
|
||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段)
|
||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||
- 地址信息:地点、地址、位置 → addresses 字段
|
||||
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段)
|
||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了")→ preserve_keywords 字段
|
||||
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
|
||||
|
||||
【场景无关内容标记】
|
||||
请从对话中识别出与当前场景({{ pruning_scene }})**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
|
||||
判断标准:
|
||||
- 与场景实体类型完全无关
|
||||
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
|
||||
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
|
||||
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
|
||||
|
||||
**重要:scene_unrelated_snippets 必须认真填写,不能为空数组。**
|
||||
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
|
||||
|
||||
示例(场景=在线教育):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
|
||||
示例(场景=情感陪伴):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日,3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分,这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
|
||||
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
|
||||
|
||||
【可以删除的内容】
|
||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
|
||||
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
|
||||
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
|
||||
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
|
||||
|
||||
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
|
||||
例如:
|
||||
- "我好开心呀" → 包含情绪(开心),必须保留,preserve_keywords 中加入"开心"
|
||||
- "好喜欢打羽毛球呀" → 包含兴趣爱好(喜欢打羽毛球),必须保留,preserve_keywords 中加入"喜欢打羽毛球"
|
||||
- "我好难过" → 包含情绪(难过),必须保留,preserve_keywords 中加入"难过"
|
||||
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪,必须保留,preserve_keywords 中加入"开心"
|
||||
|
||||
---
|
||||
对话全文:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -55,12 +120,65 @@
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||
1. Determine whether the dialogue is relevant to the specified scene;
|
||||
2. Extract all important information fragments that must be preserved.
|
||||
|
||||
Scenario Description: {{ instruction }}
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
[Scene Entity Type Definitions]
|
||||
The following entity types define what content is important in this scene.
|
||||
Content related to ANY of these types must be preserved and extracted into the keywords field:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}: {{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
[MUST PRESERVE (cannot be deleted)]
|
||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||
- Time information: dates, time points, durations, expiry dates → times field
|
||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here)
|
||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||
- Address information: locations, addresses, places → addresses field
|
||||
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here)
|
||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field
|
||||
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
|
||||
|
||||
[Scene-Unrelated Content Marking]
|
||||
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
|
||||
Criteria:
|
||||
- Completely unrelated to the scene's entity types
|
||||
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
|
||||
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
|
||||
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
|
||||
|
||||
[CAN BE DELETED]
|
||||
The following types of content are low-value and can be removed during pruning:
|
||||
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
|
||||
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
|
||||
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
|
||||
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
|
||||
|
||||
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
|
||||
Examples:
|
||||
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
|
||||
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
|
||||
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
|
||||
|
||||
---
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -74,6 +192,8 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -5,6 +5,15 @@
|
||||
===Task===
|
||||
Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
**⚠️ CRITICAL REQUIREMENTS:**
|
||||
1. **ALIASES ORDER IS CRITICAL**: The FIRST alias in the array will be used as the user's primary display name (other_name). You MUST put the most important/frequently used name FIRST.
|
||||
2. **ALWAYS include aliases field**: Even if empty, you MUST include "aliases": [] in EVERY entity.
|
||||
|
||||
<!-- TODO: v0.2.10 - denied_aliases 功能暂时禁用,将通过 Cypher 查询实现
|
||||
2. **DENIED_ALIASES**: When user explicitly denies a name (e.g., "我不叫X", "I'm not called X"), you MUST put X in denied_aliases field, NOT in aliases.
|
||||
3. **ALWAYS include both fields**: Even if empty, you MUST include "aliases": [] and "denied_aliases": [] in EVERY entity.
|
||||
-->
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。**
|
||||
{% else %}
|
||||
@@ -18,34 +27,29 @@ Extract entities and knowledge triplets from the given statement.
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
|
||||
**CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.**
|
||||
**CRITICAL: Use ONLY predefined type names below. If no exact match, use CLOSEST type. NEVER invent new types.**
|
||||
|
||||
**If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.**
|
||||
**Type Priority:**
|
||||
1. [场景类型] Scene Types (domain-specific, prefer first)
|
||||
2. [通用类型] General Types (standard ontologies)
|
||||
3. [通用父类] Parent Types (hierarchy context)
|
||||
|
||||
**Type Priority (from highest to lowest):**
|
||||
1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first
|
||||
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
|
||||
3. **[通用父类] Parent Types** - Provide type hierarchy context
|
||||
**Rules:**
|
||||
- Type MUST exactly match predefined names
|
||||
- Do NOT modify, translate, or abbreviate type names
|
||||
- Prefer scene types over general types
|
||||
|
||||
**Type Matching Rules:**
|
||||
- Entity type MUST exactly match one of the predefined type names below
|
||||
- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list
|
||||
- Do NOT modify, translate, abbreviate, or create variations of type names
|
||||
- Prefer scene types (marked [场景类型]) over general types when both could apply
|
||||
- If uncertain, check the type description to find the best match
|
||||
|
||||
**Predefined Ontology Types:**
|
||||
**Predefined Types:**
|
||||
{{ ontology_types }}
|
||||
|
||||
{% if type_hierarchy_hints %}
|
||||
**Type Hierarchy Reference:**
|
||||
The following shows type inheritance relationships (Child → Parent → Grandparent):
|
||||
**Hierarchy:**
|
||||
{% for hint in type_hierarchy_hints %}
|
||||
- {{ hint }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
**ALLOWED Type Names (use EXACTLY one of these, no exceptions):**
|
||||
**ALLOWED Names:**
|
||||
{{ ontology_type_names | join(', ') }}
|
||||
|
||||
{% endif %}
|
||||
@@ -62,66 +66,94 @@ The following shows type inheritance relationships (Child → Parent → Grandpa
|
||||
- **Entity descriptions must be in English**
|
||||
- **Examples must be in English**
|
||||
{% endif %}
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm"
|
||||
- **Technical Terms:** "Neural Network", "Database"
|
||||
* Set to `false` for:
|
||||
- **People:** "John Smith", "Dr. Wang"
|
||||
- **Organizations:** "Microsoft", "Harvard University"
|
||||
- **Locations:** "Beijing", "Central Park"
|
||||
- **Events:** "2024 Conference", "Project Meeting"
|
||||
- **Specific objects:** "iPhone 15", "Building A"
|
||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
||||
* The example should be:
|
||||
- **Specific and concrete**: Use real-world scenarios or applications
|
||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
||||
- **Semantic Memory (is_explicit_memory):**
|
||||
* `true` for: Concepts, Knowledge, Definitions, Theories, Methods (e.g., "Machine Learning", "REST API")
|
||||
* `false` for: People, Organizations, Locations, Events, Specific objects
|
||||
* For `is_explicit_memory=true`, provide concise example (~20 chars{% if language == "zh" %},使用中文{% endif %})
|
||||
|
||||
**🚨🚨🚨 ALIASES & DENIED_ALIASES - MANDATORY FIELDS 🚨🚨🚨**
|
||||
|
||||
**CRITICAL RULES (违反将导致提取失败):**
|
||||
|
||||
1. **EVERY entity MUST have aliases field:**
|
||||
- `"aliases": [...]` - REQUIRED, even if empty `[]`
|
||||
|
||||
2. **ALIASES - 别名提取规则:**
|
||||
{% if language == "zh" %}
|
||||
- **使用中文**
|
||||
- 包含:昵称、全名、简称、别称、网名等
|
||||
- 顺序:**第一个别名将作为用户的主显示名称(other_name),必须把最重要/最常用的名字放在第一位**
|
||||
- 提取顺序:严格按照对话中首次出现的顺序
|
||||
- 示例:
|
||||
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||
- 空值:如果没有别名,使用 `[]`
|
||||
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
|
||||
{% else %}
|
||||
- **In English**
|
||||
- Include: nicknames, full names, abbreviations, alternative names
|
||||
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||
- Extraction order: Strictly follow the order of first appearance in conversation
|
||||
- Examples:
|
||||
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- Empty: If no aliases, use `[]`
|
||||
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
|
||||
{% endif %}
|
||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
||||
- **Aliases Extraction:**
|
||||
|
||||
|
||||
|
||||
3. **USER ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
* 别名使用中文
|
||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||
- 用户的真实姓名:放入 aliases
|
||||
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||
- 示例:
|
||||
* "我叫李明" → name="用户", aliases=["李明"]
|
||||
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||
{% else %}
|
||||
* Aliases should be in English
|
||||
- User entity name field: use "User" or "I"
|
||||
- User's real name: put in aliases
|
||||
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||
- Examples:
|
||||
* "I'm John" → name="User", aliases=["John"]
|
||||
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||
{% endif %}
|
||||
* Include common alternative names, abbreviations and full names
|
||||
* If no aliases exist, use empty array: []
|
||||
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
|
||||
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
|
||||
|
||||
|
||||
|
||||
4. **ALIASES ORDER:**
|
||||
{% if language == "zh" %}
|
||||
- 顺序优先级:按出现顺序,先出现的在前
|
||||
{% else %}
|
||||
- Order priority: by appearance order, first mentioned comes first
|
||||
{% endif %}
|
||||
|
||||
**EXAMPLES OF CORRECT EXTRACTION:**
|
||||
{% if language == "zh" %}
|
||||
- "我叫张三" → aliases=["张三"] (张三将成为 other_name)
|
||||
- "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name)
|
||||
- "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name)
|
||||
{% else %}
|
||||
- "I'm John" → aliases=["John"] (John will become other_name)
|
||||
- "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name)
|
||||
{% endif %}
|
||||
|
||||
- Exclude lengthy quotes, dates, temporal expressions
|
||||
- Numeric values: extract as entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
|
||||
**Triplet Extraction:**
|
||||
- Extract (subject, predicate, object) triplets where:
|
||||
- Subject: main entity performing the action or being described
|
||||
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
|
||||
- Object: entity, value, or concept affected by the predicate
|
||||
- Extract (subject, predicate, object) where subject/object are entities, predicate is relationship
|
||||
{% if language == "zh" %}
|
||||
- subject_name 和 object_name 必须使用中文
|
||||
- subject_name 和 object_name 使用中文
|
||||
{% else %}
|
||||
- subject_name and object_name must be in English (translate if original is in another language)
|
||||
- subject_name and object_name in English
|
||||
{% endif %}
|
||||
- Exclude all temporal expressions from every field
|
||||
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
|
||||
- Do NOT translate predicate tokens
|
||||
- Do NOT include `statement_id` field (assigned automatically)
|
||||
|
||||
**When NOT to extract triplets:**
|
||||
- Non-propositional utterances (emotions, fillers, onomatopoeia)
|
||||
- No clear predicate from the given definitions applies
|
||||
- Standalone noun phrases or checklist items → extract as entities only
|
||||
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
|
||||
|
||||
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
|
||||
- Use ONLY predicates from "Predicate Instructions" (uppercase tokens)
|
||||
- Exclude temporal expressions, do NOT include `statement_id`
|
||||
- **When NOT to extract:** emotions, fillers, no clear predicate, standalone nouns
|
||||
- **If no valid triplet:** Return triplets: []
|
||||
{%- if predicate_instructions -%}
|
||||
|
||||
**Predicate Instructions:**
|
||||
@@ -207,26 +239,44 @@ Output:
|
||||
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (别名 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 5 (别名顺序 - Chinese):** "我叫陈思远。对了,我的网名叫「远山」"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远", "远山"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
{% if ontology_types %}
|
||||
**⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.**
|
||||
**⚠️ REMINDER: Examples use generic types for illustration. You MUST use predefined types from "ALLOWED Names" above.**
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- Use ASCII double quotes ("), escape with \"
|
||||
- No Chinese quotes (""), no line breaks in strings
|
||||
{% if language == "zh" %}
|
||||
- **语言要求:实体名称(name)、描述(description)、示例(example)、subject_name、object_name 必须使用中文**
|
||||
- **语言:name、description、example、subject_name、object_name 使用中文**
|
||||
{% else %}
|
||||
- **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English**
|
||||
- **If the original text is in Chinese, translate all names to English**
|
||||
- **Language: names, descriptions, examples in English (translate if needed)**
|
||||
{% endif %}
|
||||
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{% if language == "zh" %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
|
||||
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
|
||||
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
|
||||
- '川味可视化', '川菜' → '烹饪'
|
||||
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
|
||||
- '吉他', '指弹', '琴谱' → '吉他'
|
||||
- '跑步', '5公里', '跑鞋' → '跑步'
|
||||
- '瑜伽垫', '瑜伽课' → '瑜伽'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
|
||||
- If multiple tags all point to '攀岩', output '攀岩' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
|
||||
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in Chinese.
|
||||
|
||||
**Example**:
|
||||
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
|
||||
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
|
||||
{% else %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
|
||||
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
|
||||
- 'morning meditation streak', 'mind-body peak' → 'meditation'
|
||||
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
|
||||
- 'open source project', 'data visualization tool', 'Python' → 'programming'
|
||||
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
|
||||
- 'running', '5km', 'running shoes' → 'running'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
|
||||
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in English.
|
||||
|
||||
**Example**:
|
||||
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
|
||||
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
|
||||
{% endif %}
|
||||
@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
- User ID: {{ user_id }}
|
||||
{% if user_display_name %}
|
||||
- User Display Name: {{ user_display_name }}
|
||||
{% endif %}
|
||||
{% if entities %}
|
||||
- Core Entities & Frequency: {{ entities }}
|
||||
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
|
||||
{% if language == "zh" %}
|
||||
**【严格人称规定】**
|
||||
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
|
||||
- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户
|
||||
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
|
||||
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA)
|
||||
{% else %}
|
||||
**【STRICT PRONOUN RULES】**
|
||||
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
|
||||
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
|
||||
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
|
||||
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
|
||||
{% endif %}
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
{% if language == "zh" %}
|
||||
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: 张三
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
|
||||
|
||||
Example Output:
|
||||
【基本介绍】
|
||||
我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
|
||||
【性格特点】
|
||||
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
|
||||
@@ -121,13 +135,13 @@ Example Output:
|
||||
"让每一个产品决策都充满温度。"
|
||||
{% else %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: John
|
||||
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
|
||||
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
|
||||
|
||||
Example Output:
|
||||
【Basic Introduction】
|
||||
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
|
||||
【Personality Traits】
|
||||
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.
|
||||
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
||||
from .llm import RedBearLLM
|
||||
from .embedding import RedBearEmbeddings
|
||||
from .rerank import RedBearRerank
|
||||
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||
|
||||
__all__ = [
|
||||
"RedBearModelConfig",
|
||||
@@ -9,5 +10,7 @@ __all__ = [
|
||||
"RedBearEmbeddings",
|
||||
"RedBearRerank",
|
||||
"RedBearModelFactory",
|
||||
"get_provider_llm_class"
|
||||
"get_provider_llm_class",
|
||||
"RedBearImageGenerator",
|
||||
"RedBearVideoGenerator"
|
||||
]
|
||||
@@ -1,53 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedBearModelConfig(BaseModel):
|
||||
"""模型配置基类"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||
concurrency: int = 5 # 并发限流
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
import httpx
|
||||
if not config.base_url:
|
||||
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
timeout_config = httpx.Timeout(
|
||||
timeout=config.timeout,
|
||||
connect=60.0,
|
||||
read=config.timeout,
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
# 这样可以分别控制连接超时和读取超时
|
||||
import httpx
|
||||
@@ -65,7 +85,7 @@ class RedBearModelFactory:
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
@@ -82,7 +102,7 @@ class RedBearModelFactory:
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
from botocore.config import Config as BotoConfig
|
||||
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||
|
||||
|
||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||
# Configure with increased connection pool
|
||||
@@ -90,16 +110,16 @@ class RedBearModelFactory:
|
||||
max_pool_connections=max_pool_connections,
|
||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||
)
|
||||
|
||||
|
||||
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||
model_id = normalize_bedrock_model_id(config.model_name)
|
||||
|
||||
|
||||
params = {
|
||||
"model_id": model_id,
|
||||
"config": boto_config,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
|
||||
# 解析 API key (格式: access_key_id:secret_access_key)
|
||||
if config.api_key and ":" in config.api_key:
|
||||
access_key_id, secret_access_key = config.api_key.split(":", 1)
|
||||
@@ -107,63 +127,65 @@ class RedBearModelFactory:
|
||||
params["aws_secret_access_key"] = secret_access_key
|
||||
elif config.api_key:
|
||||
params["aws_access_key_id"] = config.api_key
|
||||
|
||||
|
||||
# 设置 region
|
||||
if config.base_url:
|
||||
params["region_name"] = config.base_url
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
||||
|
||||
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
return ChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
return OllamaEmbeddings
|
||||
@@ -173,14 +195,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_rerank_class(provider: str):
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user