Compare commits
928 Commits
v0.2.7
...
hotfix/v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9edf5c9bd6 | ||
|
|
2f0c4300df | ||
|
|
e0d7a5a91f | ||
|
|
5ac2d5602e | ||
|
|
f4c3974956 | ||
|
|
71e5b6586a | ||
|
|
bfb723a468 | ||
|
|
61f2e44bd5 | ||
|
|
ed765b7c26 | ||
|
|
3018d186f7 | ||
|
|
2e1470cb52 | ||
|
|
737858731b | ||
|
|
d072eb1af7 | ||
|
|
daaee63bd5 | ||
|
|
e3c643b659 | ||
|
|
017efdc320 | ||
|
|
29aef4527c | ||
|
|
d9cb2b511b | ||
|
|
49e0801d15 | ||
|
|
dde7ea9039 | ||
|
|
5262aedab9 | ||
|
|
441b21774d | ||
|
|
d6dd038167 | ||
|
|
47c242e513 | ||
|
|
811193dd75 | ||
|
|
797780824c | ||
|
|
75e95bab01 | ||
|
|
3c2a78a449 | ||
|
|
4f0e5d0866 | ||
|
|
7a84ee33c6 | ||
|
|
e3265e4ba3 | ||
|
|
3e7a004599 | ||
|
|
fa1e5ee43c | ||
|
|
c72a6fd724 | ||
|
|
0965008210 | ||
|
|
bcadd2a6f1 | ||
|
|
b5ec5c2cea | ||
|
|
aa683efaa0 | ||
|
|
2d9986f902 | ||
|
|
06075ffef5 | ||
|
|
a7336b0829 | ||
|
|
0d16e168e7 | ||
|
|
a882e5e5c4 | ||
|
|
c614bb5be7 | ||
|
|
1ff0f3ebfd | ||
|
|
bafcb5c545 | ||
|
|
f8d27fada6 | ||
|
|
90365cd026 | ||
|
|
d96c7b88f0 | ||
|
|
99559621c5 | ||
|
|
926f65a1ff | ||
|
|
b20971dc95 | ||
|
|
1ff0274027 | ||
|
|
8495aa5dde | ||
|
|
d8ef7a8e02 | ||
|
|
7a4a02b2bb | ||
|
|
8f623a66c8 | ||
|
|
77ed9faea1 | ||
|
|
1ff3748935 | ||
|
|
f023c43f80 | ||
|
|
70d4e79de1 | ||
|
|
62f345b3de | ||
|
|
52e726eabc | ||
|
|
9470dd2f1e | ||
|
|
ef8c7093b5 | ||
|
|
05ea372776 | ||
|
|
2b067ce08a | ||
|
|
b63cff2993 | ||
|
|
5bb9ce9018 | ||
|
|
aa581a9083 | ||
|
|
ac51ccaf1f | ||
|
|
bd955569b3 | ||
|
|
7a2a941ac4 | ||
|
|
62355186ef | ||
|
|
11ea486f82 | ||
|
|
efdee32f85 | ||
|
|
988d101e93 | ||
|
|
418f9f4dba | ||
|
|
520ee7c132 | ||
|
|
2b52b32b96 | ||
|
|
b8acc0a32f | ||
|
|
e1cf3bb3d2 | ||
|
|
6f66c9727f | ||
|
|
3beca641e1 | ||
|
|
b8507a1df6 | ||
|
|
0f28d54c43 | ||
|
|
4c2a1e6d1d | ||
|
|
7cfb6ace22 | ||
|
|
91cc20d589 | ||
|
|
f01ca51896 | ||
|
|
f4a63f7d55 | ||
|
|
0019f3acfd | ||
|
|
bc14c94407 | ||
|
|
a21dad70ed | ||
|
|
807a4e715d | ||
|
|
58d18b476c | ||
|
|
5e5927a0b9 | ||
|
|
7869121382 | ||
|
|
7c0fb624d9 | ||
|
|
af83980f99 | ||
|
|
cf0d11208c | ||
|
|
87d1630230 | ||
|
|
50392384e7 | ||
|
|
9a926a8398 | ||
|
|
e5e6699168 | ||
|
|
068e2bfb7e | ||
|
|
4ce6fede67 | ||
|
|
8497c955f9 | ||
|
|
72fe3962cf | ||
|
|
c253968aa8 | ||
|
|
d517bceda2 | ||
|
|
412183c359 | ||
|
|
90e8e90528 | ||
|
|
fd05c000f6 | ||
|
|
627d6a0381 | ||
|
|
807dee8460 | ||
|
|
cd018814fe | ||
|
|
e0b7e95af6 | ||
|
|
3a62d50048 | ||
|
|
0e60da6d8a | ||
|
|
39e94eb3ea | ||
|
|
3e0f59adc6 | ||
|
|
660cd2fadb | ||
|
|
6f1bb43eab | ||
|
|
61b5627505 | ||
|
|
af6392fb09 | ||
|
|
84b1a95313 | ||
|
|
8b21dab255 | ||
|
|
fc5ce63e44 | ||
|
|
15a863b41a | ||
|
|
5226c5b79d | ||
|
|
27e9f9968d | ||
|
|
d38612a10d | ||
|
|
32c71dcd89 | ||
|
|
428e7ebaa5 | ||
|
|
57833689d9 | ||
|
|
384a67482c | ||
|
|
7842435321 | ||
|
|
33c4c5d31b | ||
|
|
ca4f7aa65d | ||
|
|
b875626f18 | ||
|
|
130684cac0 | ||
|
|
5adff38bda | ||
|
|
62e0b2730b | ||
|
|
55b2e05ba8 | ||
|
|
562ca6c1f1 | ||
|
|
e298b38de9 | ||
|
|
a7b8ba0c66 | ||
|
|
460c86cd94 | ||
|
|
33a1c178ff | ||
|
|
c81612e6d3 | ||
|
|
9f9ac69f97 | ||
|
|
0516822d42 | ||
|
|
b598171a3d | ||
|
|
a4ea7f0385 | ||
|
|
32ae60fc65 | ||
|
|
6b272c5b44 | ||
|
|
2782d0661f | ||
|
|
ea2f5e61c9 | ||
|
|
5975d70bf9 | ||
|
|
e0546e01ef | ||
|
|
70aab94fc3 | ||
|
|
b7c1ce261b | ||
|
|
edac6a164e | ||
|
|
1503b242ea | ||
|
|
18fd48505d | ||
|
|
807ddce5cd | ||
|
|
62fb6c79a0 | ||
|
|
cc373b2864 | ||
|
|
f2d7479229 | ||
|
|
ae1909b7e9 | ||
|
|
b0aaa12340 | ||
|
|
5eb65e7ad8 | ||
|
|
cb5610e8b1 | ||
|
|
6bb01119d0 | ||
|
|
c16e832081 | ||
|
|
e3d50c5c55 | ||
|
|
e64aadce95 | ||
|
|
bad6087c25 | ||
|
|
b04c05f4a4 | ||
|
|
5e372627f7 | ||
|
|
29611738ce | ||
|
|
de846c05ab | ||
|
|
6475387af8 | ||
|
|
b330bdba29 | ||
|
|
bed279c604 | ||
|
|
9eaf779e67 | ||
|
|
1fbccd98a7 | ||
|
|
931b800bb6 | ||
|
|
d76bb36b9f | ||
|
|
3c93409f7f | ||
|
|
9451a08e7f | ||
|
|
bc49bd2a43 | ||
|
|
4bfc9ca991 | ||
|
|
1ba60401af | ||
|
|
236e8973ac | ||
|
|
ca6cc8ae63 | ||
|
|
dd2cc89c62 | ||
|
|
a87bba93c2 | ||
|
|
a153fdb7cb | ||
|
|
4eed393db5 | ||
|
|
dc8e432719 | ||
|
|
16a0099e27 | ||
|
|
c4cf639bbc | ||
|
|
f91431a70d | ||
|
|
0102ad3a30 | ||
|
|
ebe298b71d | ||
|
|
a486ca7857 | ||
|
|
dd40e5df5f | ||
|
|
8e1ec1bae6 | ||
|
|
1f9c4919be | ||
|
|
065182ad5c | ||
|
|
90ec8db0d8 | ||
|
|
78baf1c60a | ||
|
|
072c94cccb | ||
|
|
a66030d1b3 | ||
|
|
a90ceaf5a2 | ||
|
|
725f2f5146 | ||
|
|
a2c3357e80 | ||
|
|
acbc954e6f | ||
|
|
77032583ab | ||
|
|
ef533d27ac | ||
|
|
ca1a2c7b9e | ||
|
|
145fa398dd | ||
|
|
274430b2c9 | ||
|
|
e9972834fe | ||
|
|
1ecc04fee7 | ||
|
|
78cd1f69a3 | ||
|
|
aabd9a1b57 | ||
|
|
b9439b337a | ||
|
|
eb9f4f39f1 | ||
|
|
baa4b56426 | ||
|
|
49bcc6131b | ||
|
|
0d3f6f1e14 | ||
|
|
84536925c6 | ||
|
|
b22a5a9f12 | ||
|
|
b8825a83dd | ||
|
|
08b4d5c1cf | ||
|
|
a5dfc472d3 | ||
|
|
48d29bcc63 | ||
|
|
856c6f6d78 | ||
|
|
bfc47ad738 | ||
|
|
841b6abb33 | ||
|
|
8a1114a1a7 | ||
|
|
be8c481d6d | ||
|
|
5d439346a1 | ||
|
|
ed753caaf7 | ||
|
|
9a931389ea | ||
|
|
b9d469b6e3 | ||
|
|
e817cfd292 | ||
|
|
af86cb3556 | ||
|
|
e48b146e60 | ||
|
|
07b66a9801 | ||
|
|
c3ee3c4af9 | ||
|
|
cd8229f370 | ||
|
|
9a4a614fc8 | ||
|
|
0b5a030e46 | ||
|
|
675d0fc5ef | ||
|
|
6291c28f0a | ||
|
|
30b512e554 | ||
|
|
33c73c6c6f | ||
|
|
072d118935 | ||
|
|
2e7ebf174b | ||
|
|
3ece83d419 | ||
|
|
9c1c232b2e | ||
|
|
bfc98efc9d | ||
|
|
cfbf83f71e | ||
|
|
a43e8fa594 | ||
|
|
6c8d0d9d64 | ||
|
|
bd2a3bd7ef | ||
|
|
1f72b8aa70 | ||
|
|
9bb32888a2 | ||
|
|
caee5d214e | ||
|
|
38f3455bab | ||
|
|
d60cb423a4 | ||
|
|
b20a65ce29 | ||
|
|
99862db7a0 | ||
|
|
00a8099857 | ||
|
|
117e29fbe3 | ||
|
|
32740e8159 | ||
|
|
bc5ea2d421 | ||
|
|
d34bf4bc89 | ||
|
|
c4ff1a325b | ||
|
|
d1f0258065 | ||
|
|
5db59bc9cf | ||
|
|
a711635694 | ||
|
|
15b3ce3dd5 | ||
|
|
9cc19047b4 | ||
|
|
2e8e63878e | ||
|
|
38955d7d45 | ||
|
|
b6167d4e94 | ||
|
|
7890970a39 | ||
|
|
203732de1d | ||
|
|
fa4be10e51 | ||
|
|
1b52850526 | ||
|
|
1732fc7af5 | ||
|
|
a52e2137b7 | ||
|
|
377f79773d | ||
|
|
63235de42b | ||
|
|
106a32bc3a | ||
|
|
dcb7b496d3 | ||
|
|
0b47194f12 | ||
|
|
abbd92b74c | ||
|
|
960ee9f2df | ||
|
|
1c133d3d6c | ||
|
|
d270d25a99 | ||
|
|
8abd59b26e | ||
|
|
bd48b4fdbe | ||
|
|
9535545947 | ||
|
|
aad6955709 | ||
|
|
9f2cd6afae | ||
|
|
2c7aaebdd5 | ||
|
|
be38c9e385 | ||
|
|
9facb513b2 | ||
|
|
9bce14be4e | ||
|
|
59f5c7a8bb | ||
|
|
12f3a3ed77 | ||
|
|
4fb3d6992c | ||
|
|
370a668ead | ||
|
|
6eca5f6cdf | ||
|
|
f61f86f8fe | ||
|
|
57eb5aa967 | ||
|
|
1305a08c86 | ||
|
|
cf519738f4 | ||
|
|
cdebe014cf | ||
|
|
853ce6f4e1 | ||
|
|
9cbe9d5edc | ||
|
|
767f9ab17c | ||
|
|
7b5b2ab31a | ||
|
|
924d10ac5b | ||
|
|
0470a71d03 | ||
|
|
378b110d91 | ||
|
|
5f7db778b5 | ||
|
|
0d15457299 | ||
|
|
ad4ddea977 | ||
|
|
75bb96d4e7 | ||
|
|
68fdf5d76f | ||
|
|
258c19f9e0 | ||
|
|
386ed2b914 | ||
|
|
264183cec2 | ||
|
|
9561578a2a | ||
|
|
99ff07ccac | ||
|
|
e77a1a92fd | ||
|
|
d3cd66fc6e | ||
|
|
b95a627424 | ||
|
|
c9ca5df05c | ||
|
|
70c3c7dd74 | ||
|
|
b482822629 | ||
|
|
8f609ba29c | ||
|
|
a1ef5146d7 | ||
|
|
8b997b422a | ||
|
|
6d6338eb06 | ||
|
|
b5c5863b39 | ||
|
|
ab45b7abac | ||
|
|
2dfc3b25d8 | ||
|
|
3ea42ac27f | ||
|
|
fff5e0e8b8 | ||
|
|
fe29141437 | ||
|
|
17d3c81c02 | ||
|
|
ef626951bc | ||
|
|
4533644e13 | ||
|
|
ca255304d9 | ||
|
|
b40f4829cb | ||
|
|
52ae914e17 | ||
|
|
baf02e4faa | ||
|
|
87c2419186 | ||
|
|
2ad25c48d2 | ||
|
|
75e8caf441 | ||
|
|
4d6038c3cc | ||
|
|
d4450658a8 | ||
|
|
02660c7c97 | ||
|
|
3ceb2efeaf | ||
|
|
e134b96333 | ||
|
|
3ea57d1cb0 | ||
|
|
4a71484151 | ||
|
|
db8b3416a6 | ||
|
|
4df41966fe | ||
|
|
2d6cde157e | ||
|
|
abc27c8372 | ||
|
|
dbe387f666 | ||
|
|
5e70d436a8 | ||
|
|
b7198f1abd | ||
|
|
5c87a2beeb | ||
|
|
3419bb137a | ||
|
|
a00684c67d | ||
|
|
6e7c641fd4 | ||
|
|
876c39b1b0 | ||
|
|
0c677701c0 | ||
|
|
4974f9aa98 | ||
|
|
c90b58bbcd | ||
|
|
d6a243f1be | ||
|
|
418114ef72 | ||
|
|
ceed61167f | ||
|
|
83774d7443 | ||
|
|
052c7c19b3 | ||
|
|
d42db0ca33 | ||
|
|
e15af5a2ba | ||
|
|
8b44b2cd61 | ||
|
|
9d91453200 | ||
|
|
ea8db7cd90 | ||
|
|
d60f16df1b | ||
|
|
3cca35a74f | ||
|
|
8dd24533bf | ||
|
|
ed90405439 | ||
|
|
533000030f | ||
|
|
a58ac385b1 | ||
|
|
91b7f2a980 | ||
|
|
891cfc2704 | ||
|
|
f7e89af9d2 | ||
|
|
afbd8c9b4f | ||
|
|
09b3b01d37 | ||
|
|
e3dcbed5f9 | ||
|
|
c7b51e7ad8 | ||
|
|
e9ad13504a | ||
|
|
c0cd2373c0 | ||
|
|
6e757ae9e2 | ||
|
|
64a73c41d6 | ||
|
|
dae7431075 | ||
|
|
643bbbcf5c | ||
|
|
6702e86536 | ||
|
|
13e35ed122 | ||
|
|
ab2bdfa088 | ||
|
|
8285250096 | ||
|
|
e59a215078 | ||
|
|
c89eccf8fe | ||
|
|
5703fc0cb4 | ||
|
|
7acb7045f0 | ||
|
|
3aed5c447a | ||
|
|
13352178ad | ||
|
|
f9f302dd2a | ||
|
|
8f216db353 | ||
|
|
9f6026492d | ||
|
|
b699b746a5 | ||
|
|
6095170169 | ||
|
|
173697e86a | ||
|
|
5c11da6a2e | ||
|
|
96214c433f | ||
|
|
167c915631 | ||
|
|
f485398768 | ||
|
|
289b1989e5 | ||
|
|
8224848ce1 | ||
|
|
c43d258455 | ||
|
|
c3e5c8b8bb | ||
|
|
930cadcaa8 | ||
|
|
57b6b34567 | ||
|
|
f878846364 | ||
|
|
7dce63dc0b | ||
|
|
03bc8ee7f5 | ||
|
|
4aefb01b0b | ||
|
|
4e9b5736b1 | ||
|
|
46fa99a8b8 | ||
|
|
17ea92357d | ||
|
|
bd70a8b812 | ||
|
|
ad5dc3c138 | ||
|
|
e37b1b01ca | ||
|
|
e659ca9fa2 | ||
|
|
758be0087f | ||
|
|
200c13b59f | ||
|
|
32f6886000 | ||
|
|
7fbf3e8873 | ||
|
|
3026702000 | ||
|
|
8677db114b | ||
|
|
2597a1f532 | ||
|
|
4298cd7d06 | ||
|
|
8197f9db35 | ||
|
|
3da6331515 | ||
|
|
539999131c | ||
|
|
d0ca5c8b27 | ||
|
|
ee6b8ffa62 | ||
|
|
14838dc064 | ||
|
|
e017870f44 | ||
|
|
9730c5ce0f | ||
|
|
bca43fcc75 | ||
|
|
f30260939a | ||
|
|
8ba0a74473 | ||
|
|
4f69224cfd | ||
|
|
6f7fee18c9 | ||
|
|
7fd00009a2 | ||
|
|
4534b65d6a | ||
|
|
cc58c7333c | ||
|
|
c936277507 | ||
|
|
701df40270 | ||
|
|
b724dbe53a | ||
|
|
ac7c891ded | ||
|
|
a5bce221bd | ||
|
|
3ed6f49bb0 | ||
|
|
a416a6b2bd | ||
|
|
35be03803f | ||
|
|
6427018ffb | ||
|
|
06b823ff96 | ||
|
|
0fdb489227 | ||
|
|
f6394a791e | ||
|
|
4bfd4944d0 | ||
|
|
7faf291ec3 | ||
|
|
3d291e3c23 | ||
|
|
b35bedc730 | ||
|
|
4d39cdf464 | ||
|
|
a874cc70a4 | ||
|
|
2319432182 | ||
|
|
7556468c6e | ||
|
|
91d38c0648 | ||
|
|
df3d58d388 | ||
|
|
80856e3c92 | ||
|
|
8c6f395818 | ||
|
|
2f4f7219e3 | ||
|
|
4c5183eddc | ||
|
|
dfc0ee9424 | ||
|
|
8dbb067b83 | ||
|
|
1df3fc416a | ||
|
|
6223b80cc4 | ||
|
|
68489f1b28 | ||
|
|
477853b04e | ||
|
|
863be50aaf | ||
|
|
d72d57f966 | ||
|
|
5b940e5f1a | ||
|
|
9ae1d2f0d9 | ||
|
|
318f1be107 | ||
|
|
4cab6317de | ||
|
|
81bfc9af36 | ||
|
|
189013f0f8 | ||
|
|
6f5bcd18a4 | ||
|
|
c7ef97c7a6 | ||
|
|
4d4a780ab7 | ||
|
|
9d2f3aa8f9 | ||
|
|
f2c9902a07 | ||
|
|
2525f8795c | ||
|
|
b7a03a844f | ||
|
|
c13c3846d1 | ||
|
|
30b5db1e98 | ||
|
|
f92eb9f45a | ||
|
|
a136d44e27 | ||
|
|
65b2f9e6e1 | ||
|
|
5275a274c3 | ||
|
|
4f09c4fbb3 | ||
|
|
7a3220aff5 | ||
|
|
14a32778f7 | ||
|
|
2a12cb04bf | ||
|
|
1e986c641f | ||
|
|
38c6c7f053 | ||
|
|
7c0743eb8f | ||
|
|
e981f066a3 | ||
|
|
db14d40fb3 | ||
|
|
e8d575fd0b | ||
|
|
a7285e35ad | ||
|
|
c4461c4917 | ||
|
|
2df615eca0 | ||
|
|
504e5ba61e | ||
|
|
0bae290e0c | ||
|
|
294ee49d59 | ||
|
|
26c36f70e6 | ||
|
|
c4b83b1f9c | ||
|
|
14413fd413 | ||
|
|
caab58dd2f | ||
|
|
0e899bea05 | ||
|
|
1794f8f209 | ||
|
|
85daf576e9 | ||
|
|
56fd5680cf | ||
|
|
0380c13a3b | ||
|
|
9ddc523f91 | ||
|
|
491ef27b8a | ||
|
|
edd115582f | ||
|
|
45eef12842 | ||
|
|
49364802c2 | ||
|
|
8873078006 | ||
|
|
2b9fd33bc8 | ||
|
|
e86d679ae5 | ||
|
|
def7367e33 | ||
|
|
54cff5861a | ||
|
|
dc2a73155b | ||
|
|
1856c55c04 | ||
|
|
522eb569f1 | ||
|
|
9df41456f6 | ||
|
|
04c54081c8 | ||
|
|
1c49e3c167 | ||
|
|
fb6ce839d2 | ||
|
|
c7275dccac | ||
|
|
d62b484d71 | ||
|
|
8ff1c6bd08 | ||
|
|
3dcf901043 | ||
|
|
d6dfc2cb12 | ||
|
|
8a3032ce4a | ||
|
|
391c60c812 | ||
|
|
b739b032d9 | ||
|
|
3dc863cabf | ||
|
|
611b14dfea | ||
|
|
de6e2f54d2 | ||
|
|
89d188fbf3 | ||
|
|
6bba574ca6 | ||
|
|
9cbffd6408 | ||
|
|
4d2ad5757c | ||
|
|
cd0ca9cae4 | ||
|
|
3369b702e4 | ||
|
|
cbec2c1356 | ||
|
|
5987eee0a8 | ||
|
|
6348304b7d | ||
|
|
59f8010519 | ||
|
|
9308c6efae | ||
|
|
2f78b7cf5e | ||
|
|
f86448f4bf | ||
|
|
48e2e613bb | ||
|
|
1060074740 | ||
|
|
95b7df7e38 | ||
|
|
fd1634eec4 | ||
|
|
efeead41b2 | ||
|
|
a3428c2435 | ||
|
|
31b8a3764e | ||
|
|
2ff81ba101 | ||
|
|
93deb286a3 | ||
|
|
7bd97bf6d3 | ||
|
|
2d1a1b4a1f | ||
|
|
503c890d93 | ||
|
|
1f73501786 | ||
|
|
eef13cb717 | ||
|
|
c70ac1339e | ||
|
|
24c13d408e | ||
|
|
338d7f1065 | ||
|
|
27672cfaa0 | ||
|
|
4dbb2bf2e2 | ||
|
|
37bc4beab4 | ||
|
|
6056952936 | ||
|
|
31085ed678 | ||
|
|
dce7206c44 | ||
|
|
c17a2dad2d | ||
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
035cc17264 | ||
|
|
a079358028 | ||
|
|
cf26c9f39c | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
9f947a3395 | ||
|
|
bf5c4628c3 | ||
|
|
911d5e0b34 | ||
|
|
bd31aa5abf | ||
|
|
0775fad5f0 | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
fabc8936ab | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
06de54ebfd | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
7c6e48b04e | ||
|
|
b1b53f6b1d | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
fcc81ac025 | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
69c001bf84 | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
84c23e7c4e | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
ba65b06582 | ||
|
|
f4f04036f3 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
477d404727 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
88598fb9fb | ||
|
|
19d149c129 | ||
|
|
f09de3a11c | ||
|
|
e13acdc8a9 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
aab54ca1a8 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
01a1e8eab1 | ||
|
|
b71bc1f875 | ||
|
|
6a0ee22d81 | ||
|
|
cbc8714414 | ||
|
|
f6d929ab7a | ||
|
|
a7a2dabc5a | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
7b8f101824 | ||
|
|
a4c942a21f | ||
|
|
2a66775e45 | ||
|
|
f0c3d5f308 | ||
|
|
d660521c5c | ||
|
|
c612dfbc1f | ||
|
|
fc58ac0408 | ||
|
|
4f5ee24bc5 | ||
|
|
5b431400be | ||
|
|
509d1a2e24 | ||
|
|
153e68e055 | ||
|
|
77b9a6a94e | ||
|
|
d68bbab419 | ||
|
|
6d53d9178c | ||
|
|
06fe3f2f01 | ||
|
|
e2b6c713e7 | ||
|
|
0b3b241436 | ||
|
|
4c18f9e858 | ||
|
|
8fec54c085 | ||
|
|
d8e37a4d2b | ||
|
|
1da2c4fa37 |
157
.github/workflows/release-notify-wechat.yml
vendored
Normal file
157
.github/workflows/release-notify-wechat.yml
vendored
Normal file
@@ -0,0 +1,157 @@
|
||||
name: Release Notify Workflow
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [closed]
|
||||
|
||||
jobs:
|
||||
notify:
|
||||
if: >
|
||||
github.event.pull_request.merged == true &&
|
||||
startsWith(github.event.pull_request.base.ref, 'release')
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
# 防止 GitHub HEAD 未同步
|
||||
- run: sleep 3
|
||||
|
||||
# 1️⃣ 获取分支 HEAD
|
||||
- name: Get HEAD
|
||||
id: head
|
||||
run: |
|
||||
HEAD_SHA=$(curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
https://api.github.com/repos/${{ github.repository }}/git/ref/heads/${{ github.event.pull_request.base.ref }} \
|
||||
| jq -r '.object.sha')
|
||||
echo "head_sha=$HEAD_SHA" >> $GITHUB_OUTPUT
|
||||
|
||||
# 2️⃣ 判断是否最终PR
|
||||
- name: Check Latest
|
||||
id: check
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.merge_commit_sha }}" = "${{ steps.head.outputs.head_sha }}" ]; then
|
||||
echo "ok=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "ok=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
# 3️⃣ 尝试从 PR body 提取 Sourcery 摘要
|
||||
- name: Extract Sourcery Summary
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
id: sourcery
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import os, re
|
||||
|
||||
body = os.environ.get("PR_BODY", "") or ""
|
||||
match = re.search(
|
||||
r"## Summary by Sourcery\s*\n(.*?)(?=\n## |\Z)",
|
||||
body,
|
||||
re.DOTALL
|
||||
)
|
||||
|
||||
if match:
|
||||
summary = match.group(1).strip()
|
||||
found = "true"
|
||||
else:
|
||||
summary = ""
|
||||
found = "false"
|
||||
|
||||
with open("sourcery_summary.txt", "w", encoding="utf-8") as f:
|
||||
f.write(summary)
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write(f"found={found}\n")
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 4️⃣ Fallback: 获取 commits + 通义千问总结
|
||||
- name: Get Commits
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
run: |
|
||||
curl -s \
|
||||
-H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
|
||||
${{ github.event.pull_request.commits_url }} \
|
||||
| jq -r '.[].commit.message' | head -n 20 > commits.txt
|
||||
|
||||
- name: AI Summary (Qwen Fallback)
|
||||
if: steps.check.outputs.ok == 'true' && steps.sourcery.outputs.found == 'false'
|
||||
id: qwen
|
||||
env:
|
||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
with open("commits.txt", "r") as f:
|
||||
commits = f.read().strip()
|
||||
|
||||
prompt = "请用中文总结以下代码提交,输出3-5条要点,面向测试人员。直接输出编号列表,不要输出标题或前言:\n" + commits
|
||||
payload = {"model": "qwen-plus", "input": {"prompt": prompt}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
req = urllib.request.Request(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation",
|
||||
data=data,
|
||||
headers={
|
||||
"Authorization": "Bearer " + os.environ["DASHSCOPE_API_KEY"],
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
result = json.loads(resp.read().decode())
|
||||
summary = result.get("output", {}).get("text", "AI 摘要生成失败")
|
||||
|
||||
with open(os.environ["GITHUB_OUTPUT"], "a") as gh:
|
||||
gh.write("summary<<EOF\n")
|
||||
gh.write(summary + "\n")
|
||||
gh.write("EOF\n")
|
||||
PYEOF
|
||||
|
||||
# 5️⃣ 企业微信通知(Markdown)
|
||||
- name: Notify WeChat
|
||||
if: steps.check.outputs.ok == 'true'
|
||||
env:
|
||||
WECHAT_WEBHOOK: ${{ secrets.WECHAT_WEBHOOK }}
|
||||
BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
AUTHOR: ${{ github.event.pull_request.user.login }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_URL: ${{ github.event.pull_request.html_url }}
|
||||
SOURCERY_FOUND: ${{ steps.sourcery.outputs.found }}
|
||||
SOURCERY_SUMMARY: ${{ steps.sourcery.outputs.summary }}
|
||||
QWEN_SUMMARY: ${{ steps.qwen.outputs.summary }}
|
||||
run: |
|
||||
python3 << 'PYEOF'
|
||||
import json, os, urllib.request
|
||||
|
||||
if os.environ.get("SOURCERY_FOUND") == "true":
|
||||
label = "Summary by Sourcery"
|
||||
summary = os.environ.get("SOURCERY_SUMMARY", "")
|
||||
else:
|
||||
label = "AI变更摘要"
|
||||
summary = os.environ.get("QWEN_SUMMARY", "AI 摘要生成失败")
|
||||
|
||||
content = (
|
||||
"## 🚀 Release 发布通知\n"
|
||||
"> 📦 **分支**: " + os.environ["BRANCH"] + "\n"
|
||||
"> 👤 **提交人**: " + os.environ["AUTHOR"] + "\n"
|
||||
"> 📝 **标题**: " + os.environ["PR_TITLE"] + "\n\n"
|
||||
"### 🧠 " + label + "\n" +
|
||||
summary + "\n\n"
|
||||
"---\n"
|
||||
"🔗 [查看PR详情](" + os.environ["PR_URL"] + ")"
|
||||
)
|
||||
payload = {"msgtype": "markdown", "markdown": {"content": content}}
|
||||
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
req = urllib.request.Request(
|
||||
os.environ["WECHAT_WEBHOOK"],
|
||||
data=data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
resp = urllib.request.urlopen(req)
|
||||
print(resp.read().decode())
|
||||
PYEOF
|
||||
36
.github/workflows/sync-to-gitee.yml
vendored
Normal file
36
.github/workflows/sync-to-gitee.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Sync to Gitee
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # Production
|
||||
- develop # Integration
|
||||
- 'release/*' # Release preparation
|
||||
- 'hotfix/*' # Urgent fixes
|
||||
tags:
|
||||
- '*' # All version tags (v1.0.0, etc.)
|
||||
|
||||
jobs:
|
||||
sync:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout Source Code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Sync to Gitee
|
||||
run: |
|
||||
GITEE_URL="https://${{ secrets.GITEE_USERNAME }}:${{ secrets.GITEE_TOKEN }}@gitee.com/hangzhou-hongxiong-intelligent_1/MemoryBear.git"
|
||||
git remote add gitee "$GITEE_URL"
|
||||
|
||||
# 遍历并推送所有分支
|
||||
for branch in $(git branch -r | grep -v HEAD | sed 's/origin\///'); do
|
||||
echo "Syncing branch: $branch"
|
||||
git push -f gitee "origin/$branch:refs/heads/$branch"
|
||||
done
|
||||
|
||||
# 推送所有标签
|
||||
echo "Syncing tags..."
|
||||
git push gitee --tags --force
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -18,6 +18,7 @@ examples/
|
||||
.kiro
|
||||
.vscode
|
||||
.idea
|
||||
.claude
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
@@ -25,6 +26,9 @@ examples/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
redbear-mem-benchmark/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
tmp
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear empowers AI with human-like memory capabilities
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
[中文](./README_CN.md) | English
|
||||
|
||||
### [Installation Guide](#memorybear-installation-guide)
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
# MemoryBear 让AI拥有如同人类一样的记忆
|
||||
|
||||
[](LICENSE)
|
||||
[](https://www.python.org/)
|
||||
[](https://github.com/SuanmoSuanyangTechnology/MemoryBear/actions/workflows/sync-to-gitee.yml)
|
||||
|
||||
中文 | [English](./README.md)
|
||||
|
||||
### [安装教程](#memorybear安装教程)
|
||||
|
||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||
# Database connection URL - DO NOT hardcode credentials here!
|
||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||
# Example: postgresql://user:password@localhost:5432/dbname
|
||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
@@ -20,6 +23,50 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||
|
||||
# Thread-local storage for connection pools.
|
||||
# Each thread (and each forked process) gets its own pool to avoid
|
||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||
# and stale connections after fork in --pool=prefork.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||
"""Return a Redis client whose connection pool is bound to the current
|
||||
thread, process **and** event loop.
|
||||
|
||||
The pool is recreated when:
|
||||
- The PID changes (fork, Celery --pool=prefork)
|
||||
- The thread has no pool yet (Celery --pool=threads)
|
||||
- The previously-cached event loop has been closed (Celery tasks call
|
||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||
"""
|
||||
current_pid = os.getpid()
|
||||
cached_loop = getattr(_thread_local, "loop", None)
|
||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||
|
||||
if not hasattr(_thread_local, "pool") \
|
||||
or getattr(_thread_local, "pid", None) != current_pid \
|
||||
or loop_stale:
|
||||
_thread_local.pid = current_pid
|
||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||
try:
|
||||
_thread_local.loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
_thread_local.loop = None
|
||||
_thread_local.pool = ConnectionPool.from_url(
|
||||
_REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
@@ -43,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
else:
|
||||
# 设置永久键值
|
||||
await aio_redis.set(key, val)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -11,21 +12,24 @@ from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _mask_url(url: str) -> str:
|
||||
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||
# 未配置则回退到 Redis 方案
|
||||
# backend: 结果存储(使用 Redis)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||
# cannot be overridden by stray env vars.
|
||||
# See: https://github.com/celery/celery/issues/4284
|
||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
@@ -45,8 +49,8 @@ celery_app = Celery(
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"broker": _mask_url(_broker_url),
|
||||
"backend": _mask_url(_backend_url),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
@@ -70,43 +74,51 @@ celery_app.conf.update(
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||
|
||||
# Metadata extraction → memory_tasks queue
|
||||
'app.tasks.extract_user_metadata': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
@@ -115,6 +127,7 @@ celery_app.conf.update(
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -131,7 +144,7 @@ implicit_emotions_update_schedule = crontab(
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi import APIRouter
|
||||
from . import (
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
app_log_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
@@ -16,6 +17,7 @@ from . import (
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
@@ -68,6 +70,7 @@ manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
manager_router.include_router(app_controller.router)
|
||||
manager_router.include_router(app_log_controller.router)
|
||||
manager_router.include_router(upload_controller.router)
|
||||
manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
@@ -94,5 +97,6 @@ manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -53,6 +53,7 @@ def list_apps(
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
@@ -64,16 +65,42 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||
if search:
|
||||
search = search.strip()
|
||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||
if len(search) >= 10:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == search,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if matched_id:
|
||||
# 找到 API Key,直接返回关联的应用
|
||||
ids = str(matched_id)
|
||||
|
||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||
if ids is not None:
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
if app_ids:
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
# 返回标准分页格式
|
||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
# ids 为空时,返回空列表
|
||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||
return success(data=PageData(page=meta, items=[]))
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
@@ -84,6 +111,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
@@ -93,6 +121,37 @@ def list_apps(
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_my_shared_out(
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_all_apps_to_workspace(
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Cancel all app shares from current workspace to a target workspace."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.unshare_all_apps_to_workspace(
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
@@ -161,6 +220,7 @@ def delete_app(
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
payload: app_schema.CopyAppRequest = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -172,6 +232,8 @@ def copy_app(
|
||||
- 不影响原应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# body takes precedence over query param for backward compatibility
|
||||
new_name = (payload.new_name if payload else None) or new_name
|
||||
logger.info(
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
@@ -221,6 +283,36 @@ def get_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_opening(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 根据应用类型获取 features
|
||||
from app.models.app_model import App as AppModel
|
||||
app = db.get(AppModel, app_id)
|
||||
if app and app.type == "workflow":
|
||||
cfg = app_service.get_workflow_config(db=db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
else:
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
statement=opening.get("statement"),
|
||||
suggested_questions=opening.get("suggested_questions", []),
|
||||
))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
@@ -302,7 +394,8 @@ def share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
permission=payload.permission
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
@@ -333,6 +426,32 @@ def unshare_app(
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||
@cur_workspace_access_guard()
|
||||
def update_share_permission(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
payload: app_schema.UpdateSharePermissionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""更新共享权限(readonly <-> editable)
|
||||
|
||||
- 只能修改自己工作空间应用的共享权限
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
share = service.update_share_permission(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
permission=payload.permission,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(data=app_schema.AppShare.model_validate(share))
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
@@ -356,6 +475,46 @@ def list_app_shares(
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id=source_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||
|
||||
|
||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_shared_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""被共享者从自己的工作空间移除共享应用
|
||||
|
||||
- 不会删除源应用,只删除共享记录
|
||||
- 只能移除共享给自己工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.remove_shared_app(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="已移除共享应用")
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
@@ -396,7 +555,7 @@ async def draft_run(
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
@@ -413,11 +572,12 @@ async def draft_run(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -434,18 +594,29 @@ async def draft_run(
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
agent_cfg = service._agent_config_from_release(release)
|
||||
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||
else:
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
@@ -601,7 +772,17 @@ async def draft_run(
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
config = service._workflow_config_from_release(release)
|
||||
else:
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
@@ -744,6 +925,16 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -789,6 +980,13 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||
features_config: dict = agent_cfg.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -800,11 +998,11 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -831,11 +1029,11 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -881,6 +1079,14 @@ async def update_workflow_config(
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if payload.variables:
|
||||
from app.services.workflow_service import WorkflowService
|
||||
resolved = await WorkflowService(db)._resolve_variables_file_defaults(
|
||||
[v.model_dump() for v in payload.variables]
|
||||
)
|
||||
# Patch default values back into VariableDefinition objects
|
||||
for var_def, resolved_def in zip(payload.variables, resolved):
|
||||
var_def.default = resolved_def.get("default", var_def.default)
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
89
api/app/controllers/app_log_controller.py
Normal file
89
api/app/controllers/app_log_controller.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""应用日志(消息记录)接口"""
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_log_service import AppLogService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_logs(
|
||||
app_id: uuid.UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
is_draft: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看应用下所有会话记录(分页)
|
||||
|
||||
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||
- 按最新更新时间倒序排列
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversations, total = log_service.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_log_detail(
|
||||
app_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看某会话的完整消息记录
|
||||
|
||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||
- 消息 meta_data 包含模型名、token 用量等信息
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversation = log_service.get_conversation_detail(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
|
||||
return success(data=detail)
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,36 +43,38 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||||
if form_data.invite:
|
||||
auth_service.bind_workspace_with_invite(db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id)
|
||||
auth_service.bind_workspace_with_invite(
|
||||
db=db,
|
||||
user=user,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
except BusinessException as e:
|
||||
# 用户不存在且有邀请码,尝试注册
|
||||
if e.code == BizCode.USER_NOT_FOUND:
|
||||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
)
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -82,7 +87,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -110,14 +115,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -125,18 +131,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NO_ACCESS)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -167,7 +173,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -175,14 +181,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -192,5 +199,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.models.user_model import User
|
||||
from app.schemas import chunk_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import knowledge_service, document_service, file_service, knowledgeshare_service
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -460,18 +461,20 @@ async def retrieve_chunks(
|
||||
if retrieve_data.retrieve_type == chunk_schema.RetrieveType.Graph:
|
||||
kb_ids = [str(kb_id) for kb_id in private_kb_ids]
|
||||
workspace_ids = [str(workspace_id) for workspace_id in private_workspace_ids]
|
||||
llm_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.llm_id)
|
||||
emb_key = ModelApiKeyService.get_available_api_key(db, db_knowledge.embedding_id)
|
||||
# Prepare to configure chat_mdl、embedding_model、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
key=llm_key.api_key,
|
||||
model_name=llm_key.model_name,
|
||||
base_url=llm_key.api_base
|
||||
)
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
key=emb_key.api_key,
|
||||
model_name=emb_key.model_name,
|
||||
base_url=emb_key.api_base
|
||||
)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids= kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
doc = kg_retriever.retrieval(question=retrieve_data.query, workspace_ids=workspace_ids, kb_ids=kb_ids, emb_mdl=embedding_model, llm=chat_model)
|
||||
if doc:
|
||||
rs.insert(0, doc)
|
||||
return success(data=jsonable_encoder(rs), msg="retrieval successful")
|
||||
@@ -314,8 +314,10 @@ async def parse_documents(
|
||||
)
|
||||
|
||||
# 4. Check if the file exists
|
||||
api_logger.debug(f"Constructed file path: {file_path}")
|
||||
api_logger.debug(f"File metadata - kb_id: {db_file.kb_id}, parent_id: {db_file.parent_id}, file_id: {db_file.id}, extension: {db_file.file_ext}")
|
||||
if not os.path.exists(file_path):
|
||||
api_logger.warning(f"File not found (possibly deleted): file_path={file_path}")
|
||||
api_logger.error(f"File not found (possibly deleted): file_path={file_path}, file_id={db_file.id}, document_id={document_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
|
||||
@@ -14,8 +14,11 @@ Routes:
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
import httpx
|
||||
import mimetypes
|
||||
from urllib.parse import urlparse, unquote
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -47,6 +50,19 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _match_scheme(request: Request, url: str) -> str:
|
||||
"""
|
||||
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||
"""
|
||||
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||
if url.startswith("http://") and incoming_scheme == "https":
|
||||
return "https://" + url[7:]
|
||||
if url.startswith("https://") and incoming_scheme == "http":
|
||||
return "http://" + url[8:]
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
@@ -78,7 +94,7 @@ async def upload_file(
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
@@ -159,7 +175,6 @@ async def upload_file_with_share_token(
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
@@ -278,8 +293,104 @@ async def upload_file_with_share_token(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||
async def get_file_info_by_url(
|
||||
url: str,
|
||||
):
|
||||
"""
|
||||
Get file information by network URL (no authentication required).
|
||||
|
||||
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||
Falls back to GET request if HEAD is not supported.
|
||||
Returns file type, name, and size.
|
||||
|
||||
Args:
|
||||
url: The network URL of the file.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file information.
|
||||
"""
|
||||
api_logger.info(f"File info by URL request: url={url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
# Try HEAD request first
|
||||
response = await client.head(url, follow_redirects=True)
|
||||
|
||||
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||
if response.status_code != 200:
|
||||
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||
response = await client.get(url, follow_redirects=True)
|
||||
|
||||
if response.status_code != 200:
|
||||
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||
)
|
||||
|
||||
# Get file size from Content-Length header or actual content
|
||||
file_size = response.headers.get("Content-Length")
|
||||
if file_size:
|
||||
file_size = int(file_size)
|
||||
elif hasattr(response, 'content'):
|
||||
file_size = len(response.content)
|
||||
else:
|
||||
file_size = None
|
||||
|
||||
# Get content type from Content-Type header
|
||||
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||
# Remove charset and other parameters from content type
|
||||
content_type = content_type.split(';')[0].strip()
|
||||
|
||||
# Extract filename from Content-Disposition or URL
|
||||
file_name = None
|
||||
content_disposition = response.headers.get("Content-Disposition")
|
||||
if content_disposition and "filename=" in content_disposition:
|
||||
parts = content_disposition.split("filename=")
|
||||
if len(parts) > 1:
|
||||
file_name = parts[1].strip('"').strip("'")
|
||||
|
||||
if not file_name:
|
||||
parsed_url = urlparse(url)
|
||||
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||
|
||||
# Extract file extension from filename
|
||||
_, file_ext = os.path.splitext(file_name)
|
||||
|
||||
# If no extension found, infer from content type
|
||||
if not file_ext:
|
||||
ext = mimetypes.guess_extension(content_type)
|
||||
if ext:
|
||||
file_ext = ext
|
||||
file_name = f"{file_name}{file_ext}"
|
||||
|
||||
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"file_name": file_name,
|
||||
"file_ext": file_ext.lower() if file_ext else "",
|
||||
"file_size": file_size,
|
||||
"content_type": content_type,
|
||||
},
|
||||
msg="File information retrieved successfully"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file information: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -327,6 +438,7 @@ async def download_file(
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
@@ -400,6 +512,7 @@ async def delete_file(
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
@@ -461,8 +574,13 @@ async def get_file_url(
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||
url = await storage_service.get_file_url(
|
||||
file_key,
|
||||
expires=expires,
|
||||
file_name=file_metadata.file_name,
|
||||
)
|
||||
url = _match_scheme(request, url)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
@@ -482,8 +600,54 @@ async def get_file_url(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||
async def get_permanent_file_url(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
获取文件的永久公开 URL(无过期时间)。
|
||||
|
||||
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||
"""
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if isinstance(storage, LocalStorage):
|
||||
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||
else:
|
||||
url = await storage.get_permanent_url(file_key)
|
||||
if not url:
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="Permanent URL not supported for current storage backend")
|
||||
|
||||
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||
return success(
|
||||
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
@@ -555,6 +719,7 @@ async def public_download_file(
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -566,6 +731,7 @@ async def public_download_file(
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
@@ -624,7 +790,8 @@ async def permanent_download_file(
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -632,3 +799,44 @@ async def permanent_download_file(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||
async def get_file_status(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get file upload/processing status (no authentication required).
|
||||
|
||||
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||
Returns status: pending, completed, or failed.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
|
||||
Returns:
|
||||
ApiResponse with file status and metadata.
|
||||
"""
|
||||
api_logger.info(f"File status request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"file_id": str(file_id),
|
||||
"status": file_metadata.status,
|
||||
"file_name": file_metadata.file_name,
|
||||
"file_size": file_metadata.file_size,
|
||||
"content_type": file_metadata.content_type,
|
||||
},
|
||||
msg="File status retrieved successfully"
|
||||
)
|
||||
|
||||
@@ -3,9 +3,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.home_page_repository import HomePageRepository
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.home_page_service import HomePageService
|
||||
|
||||
@@ -31,9 +32,32 @@ def get_workspace_list(
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
|
||||
def get_system_version():
|
||||
"""获取系统版本号+说明"""
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
"""获取系统版本号 + 说明"""
|
||||
current_version = None
|
||||
version_info = None
|
||||
|
||||
# 1️⃣ 优先从数据库获取最新已发布的版本
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
current_version, version_info = HomePageRepository.get_latest_version_introduction(db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# 2️⃣ 降级:使用环境变量中的版本号
|
||||
if not current_version:
|
||||
current_version = settings.SYSTEM_VERSION
|
||||
version_info = HomePageService.load_version_introduction(current_version)
|
||||
|
||||
# 3️⃣ 如果数据库和 JSON 都没有,返回基本信息
|
||||
if not version_info:
|
||||
version_info = {
|
||||
"introduction": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []},
|
||||
"introduction_en": {"codeName": "", "releaseDate": "", "upgradePosition": "", "coreUpgrades": []}
|
||||
}
|
||||
|
||||
return success(
|
||||
data={
|
||||
"version": current_version,
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
|
||||
def get_languages(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of all supported languages.
|
||||
|
||||
Returns:
|
||||
List of language information including code, name, and status
|
||||
"""
|
||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Get available locales from translation service
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Build language info list
|
||||
languages = []
|
||||
for locale in available_locales:
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
# Get native names
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
languages.append(language_info)
|
||||
|
||||
response = LanguageListResponse(languages=languages)
|
||||
|
||||
api_logger.info(f"Returning {len(languages)} languages")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||
def get_language(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get information about a specific language.
|
||||
|
||||
Args:
|
||||
locale: Language code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Language information
|
||||
"""
|
||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Build language info
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
api_logger.info(f"Returning language info for: {locale}")
|
||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
|
||||
def add_language(
|
||||
request: LanguageCreateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Add a new language (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual language addition
|
||||
requires creating translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
request: Language creation request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language already exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if request.code in available_locales:
|
||||
api_logger.warning(f"Language already exists: {request.code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.language.already_exists", locale=request.code)
|
||||
)
|
||||
|
||||
# Note: Actual language addition requires creating translation files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language addition validated: {request.code}. "
|
||||
"Translation files need to be created manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t(
|
||||
"i18n.language.add_instructions",
|
||||
locale=request.code,
|
||||
dir=settings.I18N_CORE_LOCALES_DIR
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||
def update_language(
|
||||
locale: str,
|
||||
request: LanguageUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update language configuration (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual configuration
|
||||
changes require updating environment variables or config files.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
request: Language update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Note: Actual configuration changes require updating settings
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language update validated: {locale}. "
|
||||
"Configuration changes require environment variable updates."
|
||||
)
|
||||
|
||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
|
||||
def get_all_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for all or specific locale.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
All translations organized by locale and namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
if locale:
|
||||
# Get translations for specific locale
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = {
|
||||
locale: translation_service._cache.get(locale, {})
|
||||
}
|
||||
else:
|
||||
# Get all translations
|
||||
translations = translation_service._cache
|
||||
|
||||
response = TranslationResponse(translations=translations)
|
||||
|
||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||
def get_locale_translations(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
|
||||
Returns:
|
||||
All translations for the locale organized by namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = translation_service._cache.get(locale, {})
|
||||
|
||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||
def get_namespace_translations(
|
||||
locale: str,
|
||||
namespace: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get translations for a specific namespace in a locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||
|
||||
Returns:
|
||||
Translations for the specified namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get namespace translations request: locale={locale}, "
|
||||
f"namespace={namespace}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Get namespace translations
|
||||
locale_translations = translation_service._cache.get(locale, {})
|
||||
namespace_translations = locale_translations.get(namespace, {})
|
||||
|
||||
if not namespace_translations:
|
||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||
)
|
||||
return success(
|
||||
data={
|
||||
"locale": locale,
|
||||
"namespace": namespace,
|
||||
"translations": namespace_translations
|
||||
},
|
||||
msg=t("common.success.retrieved")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||
def update_translation(
|
||||
locale: str,
|
||||
key: str,
|
||||
request: TranslationUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update a single translation (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual translation updates
|
||||
require modifying translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
request: Translation update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update translation request: locale={locale}, key={key}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Validate key format
|
||||
if "." not in key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||
)
|
||||
|
||||
# Note: Actual translation updates require modifying JSON files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Translation update validated: {locale}/{key}. "
|
||||
"Translation files need to be updated manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/translations/missing", response_model=ApiResponse)
|
||||
def get_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of missing translations.
|
||||
|
||||
Compares translations across locales to find missing keys.
|
||||
|
||||
Args:
|
||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||
|
||||
Returns:
|
||||
List of missing translation keys
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Get default locale translations as reference
|
||||
default_translations = translation_service._cache.get(default_locale, {})
|
||||
|
||||
# Collect all keys from default locale
|
||||
def collect_keys(data, prefix=""):
|
||||
keys = []
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, dict):
|
||||
keys.extend(collect_keys(value, full_key))
|
||||
else:
|
||||
keys.append(full_key)
|
||||
return keys
|
||||
|
||||
default_keys = set()
|
||||
for namespace, translations in default_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
default_keys.update(namespace_keys)
|
||||
|
||||
# Find missing keys in target locale(s)
|
||||
missing_by_locale = {}
|
||||
|
||||
target_locales = [locale] if locale else [
|
||||
loc for loc in available_locales if loc != default_locale
|
||||
]
|
||||
|
||||
for target_locale in target_locales:
|
||||
if target_locale not in available_locales:
|
||||
continue
|
||||
|
||||
target_translations = translation_service._cache.get(target_locale, {})
|
||||
target_keys = set()
|
||||
|
||||
for namespace, translations in target_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
target_keys.update(namespace_keys)
|
||||
|
||||
missing_keys = default_keys - target_keys
|
||||
if missing_keys:
|
||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||
|
||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||
|
||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
|
||||
def reload_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Trigger hot reload of translation files (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to reload (defaults to reloading all locales)
|
||||
|
||||
Returns:
|
||||
Reload status and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Reload translations request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||
api_logger.warning("Hot reload is disabled in configuration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=t("i18n.reload.disabled")
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
try:
|
||||
# Reload translations
|
||||
translation_service.reload(locale)
|
||||
|
||||
# Get statistics
|
||||
available_locales = translation_service.get_available_locales()
|
||||
reloaded_locales = [locale] if locale else available_locales
|
||||
|
||||
response = ReloadResponse(
|
||||
success=True,
|
||||
reloaded_locales=reloaded_locales,
|
||||
total_locales=len(available_locales)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||
)
|
||||
|
||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=t("i18n.reload.failed", error=str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
|
||||
def get_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get i18n performance metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Performance metrics including:
|
||||
- Request counts
|
||||
- Missing translations
|
||||
- Timing statistics
|
||||
- Locale usage
|
||||
- Error counts
|
||||
"""
|
||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
metrics = translation_service.get_metrics_summary()
|
||||
|
||||
api_logger.info("Returning i18n metrics")
|
||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||
def get_cache_stats(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get cache statistics (admin only).
|
||||
|
||||
Returns:
|
||||
Cache statistics including:
|
||||
- Hit/miss rates
|
||||
- LRU cache performance
|
||||
- Loaded locales
|
||||
- Memory usage
|
||||
"""
|
||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
cache_stats = translation_service.get_cache_stats()
|
||||
memory_usage = translation_service.get_memory_usage()
|
||||
|
||||
data = {
|
||||
"cache": cache_stats,
|
||||
"memory": memory_usage
|
||||
}
|
||||
|
||||
api_logger.info("Returning cache statistics")
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
|
||||
def get_prometheus_metrics(
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get metrics in Prometheus format (admin only).
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as plain text
|
||||
"""
|
||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
prometheus_output = metrics.export_prometheus()
|
||||
|
||||
from fastapi.responses import PlainTextResponse
|
||||
return PlainTextResponse(content=prometheus_output)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||
def reset_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Reset all metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
metrics.reset()
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_service.cache.reset_stats()
|
||||
|
||||
api_logger.info("Metrics reset completed")
|
||||
return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
|
||||
def get_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get missing translation logs (admin only).
|
||||
|
||||
Returns logged missing translations with context information.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
limit: Maximum number of entries to return (default: 100)
|
||||
|
||||
Returns:
|
||||
Missing translation logs with context
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translation logs request: locale={locale}, "
|
||||
f"limit={limit}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Get missing translations
|
||||
missing_translations = translation_logger.get_missing_translations(locale)
|
||||
|
||||
# Get missing with context
|
||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||
|
||||
# Get statistics
|
||||
statistics = translation_logger.get_statistics()
|
||||
|
||||
data = {
|
||||
"missing_translations": missing_translations,
|
||||
"recent_context": missing_with_context,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"Returning {statistics['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||
def generate_missing_translation_report(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Generate a comprehensive missing translation report (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Comprehensive report with missing translations and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Generate missing translation report request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate report
|
||||
report = translation_logger.generate_report(locale)
|
||||
|
||||
api_logger.info(
|
||||
f"Generated report with {report['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=report, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||
def export_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Export missing translations to JSON file (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Export status and file path
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Export missing translations request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
locale_suffix = f"_{locale}" if locale else "_all"
|
||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||
|
||||
# Export to file
|
||||
translation_logger.export_to_json(output_file)
|
||||
|
||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||
return success(
|
||||
data={"file_path": output_file},
|
||||
msg=t("i18n.logs.export_success", file=output_file)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||
def clear_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Clear missing translation logs (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to clear (clears all if not specified)
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Clear logs
|
||||
translation_logger.clear(locale)
|
||||
|
||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||
return success(msg=t("i18n.logs.clear_success"))
|
||||
@@ -352,6 +352,7 @@ async def delete_knowledge(
|
||||
# 2. Soft-delete knowledge base
|
||||
api_logger.debug(f"Perform a soft delete: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
db_knowledge.status = 2
|
||||
db_knowledge.updated_at = datetime.datetime.now()
|
||||
db.commit()
|
||||
api_logger.info(f"The knowledge base has been successfully deleted: {db_knowledge.name} (ID: {knowledge_id})")
|
||||
return success(msg="The knowledge base has been successfully deleted")
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
from app.services import mcp_market_config_service, mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -91,9 +91,11 @@ async def get_mcp_servers(
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
headers=api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
@@ -123,6 +125,17 @@ async def get_mcp_servers(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
# 5. Update mck_market.mcp_count
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
db_mcp_market.mcp_count = total
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@@ -162,6 +175,7 @@ async def get_operational_mcp_servers(
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
@@ -249,7 +263,9 @@ async def create_mcp_market_config(
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
@@ -265,6 +281,32 @@ async def create_mcp_market_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
# 2. verify token
|
||||
create_data.status = 1
|
||||
try:
|
||||
api = MCPApi()
|
||||
token = create_data.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': 1,
|
||||
'page_size': 20,
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=headers,
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
create_data.status = 0
|
||||
# 3. create mcp_market_config
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
@@ -358,7 +400,9 @@ async def update_mcp_market_config(
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
headers = api.builder_headers(api.headers)
|
||||
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
@@ -395,7 +439,7 @@ async def update_mcp_market_config(
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
# 5. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@@ -118,142 +118,142 @@ async def download_log(
|
||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
if hasattr(e, 'exceptions'):
|
||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
detailed_error = "; ".join(error_messages)
|
||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
user=current_user
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
if workspace_id:
|
||||
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=db,
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
# @router.post("/writer_service", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Write service endpoint - processes write operations synchronously
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Response with write operation status
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
#
|
||||
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
# if storage_type == 'rag':
|
||||
# if workspace_id:
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge:
|
||||
# user_rag_memory_id = str(knowledge.id)
|
||||
# else:
|
||||
# api_logger.warning(
|
||||
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
# else:
|
||||
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
# storage_type = 'neo4j'
|
||||
#
|
||||
# api_logger.info(
|
||||
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
# try:
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
# result = await memory_agent_service.write_memory(
|
||||
# user_input.end_user_id,
|
||||
# messages_list,
|
||||
# config_id,
|
||||
# db,
|
||||
# storage_type,
|
||||
# user_rag_memory_id,
|
||||
# language
|
||||
# )
|
||||
#
|
||||
# return success(data=result, msg="写入成功")
|
||||
# except BaseException as e:
|
||||
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
# if hasattr(e, 'exceptions'):
|
||||
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||
# detailed_error = "; ".join(error_messages)
|
||||
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
#
|
||||
#
|
||||
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||
# @cur_workspace_access_guard()
|
||||
# async def write_server_async(
|
||||
# user_input: Write_UserInput,
|
||||
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user)
|
||||
# ):
|
||||
# """
|
||||
# Async write service endpoint - enqueues write processing to Celery
|
||||
#
|
||||
# Args:
|
||||
# user_input: Write request containing message and end_user_id
|
||||
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
#
|
||||
# Returns:
|
||||
# Task ID for tracking async operation
|
||||
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
# """
|
||||
# # 使用集中化的语言校验
|
||||
# language = get_language_from_header(language_type)
|
||||
#
|
||||
# config_id = user_input.config_id
|
||||
# workspace_id = current_user.current_workspace_id
|
||||
# api_logger.info(
|
||||
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
#
|
||||
# # 获取 storage_type,如果为 None 则使用默认值
|
||||
# storage_type = workspace_service.get_workspace_storage_type(
|
||||
# db=db,
|
||||
# workspace_id=workspace_id,
|
||||
# user=current_user
|
||||
# )
|
||||
# if storage_type is None: storage_type = 'neo4j'
|
||||
# user_rag_memory_id = ''
|
||||
# if workspace_id:
|
||||
#
|
||||
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
# db=db,
|
||||
# name="USER_RAG_MERORY",
|
||||
# workspace_id=workspace_id
|
||||
# )
|
||||
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# try:
|
||||
# # 获取标准化的消息列表
|
||||
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
#
|
||||
# task = celery_app.send_task(
|
||||
# "app.core.memory.agent.write_message",
|
||||
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
# )
|
||||
# api_logger.info(f"Write task queued: {task.id}")
|
||||
#
|
||||
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
# except Exception as e:
|
||||
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||
|
||||
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含宿主列表和分页信息
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||
if workspace_id is None:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||
return success(data={
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
@@ -187,14 +181,27 @@ async def get_workspace_end_users(
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
# 构建分页响应
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -584,7 +591,7 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 1. 获取记忆总量(total_memory)
|
||||
# 1. 获取记忆总量(total_memory)—— neo4j 独有逻辑:查询 neo4j 存储节点
|
||||
try:
|
||||
total_memory_data = await memory_dashboard_service.get_workspace_total_memory_count(
|
||||
db=db,
|
||||
@@ -593,46 +600,33 @@ async def dashboard_data(
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
neo4j_data["total_app"] = len(apps_orm)
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取知识库类型统计(total_knowledge)
|
||||
try:
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
memory_agent_service = MemoryAgentService()
|
||||
knowledge_stats = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
only_active=True,
|
||||
current_workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
neo4j_data["total_knowledge"] = knowledge_stats.get("total", 0)
|
||||
api_logger.info(f"成功获取知识库类型统计total: {neo4j_data['total_knowledge']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
neo4j_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
# 计算昨日对比
|
||||
try:
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
storage_type=storage_type,
|
||||
today_data=neo4j_data
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
neo4j_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
api_logger.warning(f"计算neo4j昨日对比失败: {str(e)}")
|
||||
neo4j_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
|
||||
@@ -645,41 +639,37 @@ async def dashboard_data(
|
||||
"total_api_call": None
|
||||
}
|
||||
|
||||
# 获取RAG相关数据
|
||||
# 1. 获取记忆总量(total_memory)—— rag 独有逻辑:查询 document 表的 chunk_num
|
||||
try:
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
rag_data["total_app"] = len(apps_orm)
|
||||
|
||||
# total_knowledge: 使用 total_kb(总知识库数)
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
api_logger.info(f"成功获取RAG记忆总量: {total_chunk}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
api_logger.warning(f"获取RAG记忆总量失败: {str(e)}")
|
||||
|
||||
# 2. 获取共享统计数据(total_app、total_knowledge、total_api_call)
|
||||
common_stats = memory_dashboard_service.get_dashboard_common_stats(db, workspace_id)
|
||||
rag_data.update(common_stats)
|
||||
api_logger.info(f"成功获取共享统计: app={common_stats['total_app']}, knowledge={common_stats['total_knowledge']}, api_call={common_stats['total_api_call']}")
|
||||
|
||||
# 计算昨日对比
|
||||
try:
|
||||
changes = memory_dashboard_service.get_dashboard_yesterday_changes(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
storage_type=storage_type,
|
||||
today_data=rag_data
|
||||
)
|
||||
rag_data.update(changes)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"计算RAG昨日对比失败: {str(e)}")
|
||||
rag_data.update({
|
||||
"total_memory_change": None,
|
||||
"total_app_change": None,
|
||||
"total_knowledge_change": None,
|
||||
"total_api_call_change": None,
|
||||
})
|
||||
|
||||
result["rag_data"] = rag_data
|
||||
api_logger.info("成功获取rag_data")
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
||||
)
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
# 跳过没有配置的应用
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
@@ -204,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
|
||||
analytics_hot_memory_tags,
|
||||
analytics_recent_activity_stats,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_all_batch,
|
||||
search_chunk,
|
||||
search_detials,
|
||||
search_dialogue,
|
||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
||||
|
||||
@router.get("/info", response_model=ApiResponse)
|
||||
async def get_storage_info(
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
storage_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Example wrapper endpoint - retrieves storage information
|
||||
@@ -75,24 +75,19 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
|
||||
try:
|
||||
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
|
||||
@@ -107,9 +102,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
@@ -119,9 +116,11 @@ def create_config(
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
@@ -129,10 +128,10 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
@@ -145,24 +144,24 @@ def delete_config(
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
@@ -172,7 +171,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
@@ -186,7 +185,7 @@ def delete_config(
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
@@ -195,7 +194,7 @@ def delete_config(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
@@ -203,9 +202,9 @@ def delete_config(
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -213,12 +212,13 @@ def update_config(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||
"config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -231,9 +231,9 @@ def update_config(
|
||||
|
||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||
def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
@@ -241,7 +241,7 @@ def update_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -256,11 +256,11 @@ def update_config_extracted(
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -268,7 +268,7 @@ def read_config_extracted(
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -278,18 +278,19 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -303,14 +304,14 @@ def read_all_config(
|
||||
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
@@ -333,9 +334,9 @@ async def pilot_run(
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await kb_type_distribution(end_user_id)
|
||||
@@ -344,12 +345,12 @@ async def get_kb_type_distribution(
|
||||
api_logger.error(f"KB type distribution failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||
async def search_dialogues_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_dialogue(end_user_id)
|
||||
@@ -361,9 +362,9 @@ async def search_dialogues_num(
|
||||
|
||||
@router.get("/search/chunk", response_model=ApiResponse)
|
||||
async def search_chunks_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_chunk(end_user_id)
|
||||
@@ -375,9 +376,9 @@ async def search_chunks_num(
|
||||
|
||||
@router.get("/search/statement", response_model=ApiResponse)
|
||||
async def search_statements_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_statement(end_user_id)
|
||||
@@ -389,9 +390,9 @@ async def search_statements_num(
|
||||
|
||||
@router.get("/search/entity", response_model=ApiResponse)
|
||||
async def search_entities_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity(end_user_id)
|
||||
@@ -403,12 +404,15 @@ async def search_entities_num(
|
||||
|
||||
@router.get("/search", response_model=ApiResponse)
|
||||
async def search_all_num(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_all(end_user_id)
|
||||
if not end_user_id:
|
||||
return success(data={"total": 0}, msg="查询成功")
|
||||
batch_result = await search_all_batch([end_user_id])
|
||||
result = {"total": batch_result.get(end_user_id, 0)}
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search all failed: {str(e)}")
|
||||
@@ -417,9 +421,9 @@ async def search_all_num(
|
||||
|
||||
@router.get("/search/detials", response_model=ApiResponse)
|
||||
async def search_entities_detials(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_detials(end_user_id)
|
||||
@@ -431,9 +435,9 @@ async def search_entities_detials(
|
||||
|
||||
@router.get("/search/edges", response_model=ApiResponse)
|
||||
async def search_entity_edges(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_edges(end_user_id)
|
||||
@@ -443,14 +447,12 @@ async def search_entity_edges(
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_api(
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
limit: int = 10,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
@@ -461,18 +463,18 @@ async def get_hot_memory_tags_api(
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
@@ -481,11 +483,11 @@ async def get_hot_memory_tags_api(
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
@@ -495,9 +497,9 @@ async def get_hot_memory_tags_api(
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
@@ -505,8 +507,8 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
@@ -516,12 +518,12 @@ async def clear_hot_memory_tags_cache(
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
@@ -530,12 +532,12 @@ async def clear_hot_memory_tags_cache(
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
@@ -543,7 +545,7 @@ async def clear_hot_memory_tags_cache(
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
@@ -553,4 +555,3 @@ async def get_recent_activity_stats_api(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import conversation_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
@@ -32,35 +33,47 @@ def get_memory_count(
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
end_user_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
Retrieve conversations for the current user in a specific group with pagination.
|
||||
|
||||
Args:
|
||||
end_user_id (UUID): The group identifier.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
pagesize (int): Number of items per page. Defaults to 20.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains a list of conversation IDs.
|
||||
|
||||
Notes:
|
||||
- Initializes the ConversationService with the current DB session.
|
||||
- Returns only conversation IDs for lightweight response.
|
||||
- Logs can be added to trace requests in production.
|
||||
ApiResponse: Contains a paginated list of conversations.
|
||||
"""
|
||||
page = max(1, page)
|
||||
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
end_user_id
|
||||
conversations, total = conversation_service.get_user_conversations(
|
||||
end_user_id,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
], msg="get conversations success")
|
||||
return success(data={
|
||||
"items": [
|
||||
{
|
||||
"id": conversation.id,
|
||||
"title": conversation.title
|
||||
} for conversation in conversations
|
||||
],
|
||||
"total": total,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": page_size,
|
||||
"total": total,
|
||||
"hasnext": (page * page_size) < total
|
||||
},
|
||||
}, msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
@@ -90,11 +103,7 @@ def get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
conversation_schema.Message.model_validate(message)
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
@@ -74,10 +75,21 @@ def get_model_list(
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
capability_list = []
|
||||
if capability is not None:
|
||||
flat_capability = []
|
||||
for item in capability:
|
||||
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||
flat_capability.extend(split_items)
|
||||
|
||||
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||
capability_list = unique_flat_capability
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
capability=capability_list,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
|
||||
@@ -163,6 +163,7 @@ def _get_ontology_service(
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
support_thinking="thinking" in (api_key_config.capability or []),
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
|
||||
@@ -124,10 +124,11 @@ async def get_prompt_opt(
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield f"event:message\ndata: {json.dumps(chunk, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
{"error": str(e)},
|
||||
ensure_ascii=False
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
@@ -22,11 +21,13 @@ from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
@@ -215,8 +216,11 @@ def list_conversations(
|
||||
service = SharedChatService(db)
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
@@ -256,8 +260,41 @@ def get_conversation(
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
file_ids = []
|
||||
message_file_id_map = {}
|
||||
|
||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||
for idx, m in enumerate(messages):
|
||||
if m.role == "assistant" and m.meta_data:
|
||||
audio_url = m.meta_data.get("audio_url")
|
||||
if not audio_url:
|
||||
continue
|
||||
try:
|
||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# audio_url 无法解析为 UUID,标记为 unknown
|
||||
m.meta_data["audio_status"] = "unknown"
|
||||
continue
|
||||
|
||||
file_ids.append(file_id)
|
||||
message_file_id_map[idx] = file_id
|
||||
|
||||
# 批量查询所有相关的 FileMetadata
|
||||
file_status_map = {}
|
||||
if file_ids:
|
||||
file_metas = (
|
||||
db.query(FileMetadata)
|
||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||
.all()
|
||||
)
|
||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||
|
||||
# 第二次遍历:将查询结果映射回消息
|
||||
for idx, file_id in message_file_id_map.items():
|
||||
m = messages[idx]
|
||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
@@ -308,25 +345,39 @@ async def chat(
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
original_user_id=user_id
|
||||
)
|
||||
|
||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||
if not new_end_user.memory_config_id:
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||
if memory_config_id:
|
||||
new_end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
db.refresh(new_end_user)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid = share.app_id
|
||||
# appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
# app = db.query(App).filter(
|
||||
# App.id == appid,
|
||||
# App.is_active.is_(True)
|
||||
# ).first()
|
||||
# if not app:
|
||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
# workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
@@ -402,31 +453,10 @@ async def chat(
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -452,20 +482,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -524,48 +540,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
@@ -610,11 +584,11 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -654,17 +628,23 @@ async def config_query(
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
"variables": workflow_service.get_start_node_variables(release.config),
|
||||
"memory": workflow_service.is_memory_enable(release.config),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
"variables": release.config.get("variables"),
|
||||
"memory": release.config.get("memory", {}).get("enabled"),
|
||||
"features": release.config.get("features"),
|
||||
"model_parameters": release.config.get("model_parameters")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
"variables": [],
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import AppChatRequest, conversation_schema
|
||||
@@ -61,18 +62,18 @@ async def list_apps():
|
||||
# return success(data={"received": True}, msg="消息已接收")
|
||||
|
||||
|
||||
def _checkAppConfig(app: App):
|
||||
if app.type == AppType.AGENT:
|
||||
if not app.current_release.config:
|
||||
def _checkAppConfig(release: AppRelease):
|
||||
if release.type == AppType.AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.MULTI_AGENT:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.MULTI_AGENT:
|
||||
if not release.config:
|
||||
raise BusinessException("Multi-Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app.type == AppType.WORKFLOW:
|
||||
if not app.current_release.config:
|
||||
elif release.type == AppType.WORKFLOW:
|
||||
if not release.config:
|
||||
raise BusinessException("工作流应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
raise BusinessException("不支持的应用类型", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@@ -86,17 +87,29 @@ async def chat(
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
"""
|
||||
Agent/Workflow 聊天接口
|
||||
|
||||
- 不传 version:使用当前生效版本(current_release,回滚后为回滚目标版本)
|
||||
- 传 version=release_id:使用指定版本uuid的历史快照,例如 {"version": "{{release_id}}"}
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
|
||||
# 版本切换:指定 release_id 时查找对应历史快照,否则使用当前激活版本
|
||||
if payload.version is not None:
|
||||
active_release = app_service.get_release_by_id(app.id, payload.version)
|
||||
else:
|
||||
active_release = app.current_release
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search = True
|
||||
@@ -127,7 +140,7 @@ async def chat(
|
||||
storage_type = 'neo4j'
|
||||
app_type = app.type
|
||||
# check app config
|
||||
_checkAppConfig(app)
|
||||
_checkAppConfig(active_release)
|
||||
|
||||
# 获取或创建会话(提前验证)
|
||||
conversation = conversation_service.create_or_get_conversation(
|
||||
@@ -142,8 +155,13 @@ async def chat(
|
||||
|
||||
# print("="*50)
|
||||
# print(app.current_release.default_model_config_id)
|
||||
agent_config = agent_config_4_app_release(app.current_release)
|
||||
agent_config = agent_config_4_app_release(active_release)
|
||||
# print(agent_config.default_model_config_id)
|
||||
|
||||
# thinking 开关:仅当 agent 配置了 deep_thinking 且请求 thinking=True 时才启用
|
||||
if not (agent_config.model_parameters.get("deep_thinking", False) and payload.thinking):
|
||||
agent_config.model_parameters["deep_thinking"] = False
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -189,7 +207,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# 多 Agent 流式返回
|
||||
config = multi_agent_config_4_app_release(app.current_release)
|
||||
config = multi_agent_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
@@ -232,7 +250,7 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
config = workflow_config_4_app_release(active_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
@@ -248,7 +266,7 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
release_id=active_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
@@ -280,9 +298,10 @@ async def chat(
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
release_id=active_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
@@ -296,6 +315,4 @@ async def chat(
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
92
api/app/controllers/service/end_user_api_controller.py
Normal file
92
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""End User 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
|
||||
Creates a new end user and connects it to a memory configuration.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
|
||||
logger.info("Create end user request - other_id: %s, workspace_id: %s", payload.other_id, workspace_id)
|
||||
|
||||
# Resolve memory_config_id: explicit > workspace default
|
||||
memory_config_id = None
|
||||
config_service = MemoryConfigService(db)
|
||||
|
||||
if payload.memory_config_id:
|
||||
try:
|
||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
f"Memory config not found: {payload.memory_config_id}",
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||
)
|
||||
memory_config_id = config.config_id
|
||||
else:
|
||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||
if default_config:
|
||||
memory_config_id = default_config.config_id
|
||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=api_key_auth.resource_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -6,6 +6,9 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryWriteRequest,
|
||||
@@ -31,14 +34,15 @@ async def write_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
||||
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
@@ -62,13 +66,15 @@ async def read_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryReadRequest = Body(..., embed=False),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
@@ -85,3 +91,55 @@ async def read_memory_api_service(
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def list_memory_configs(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs for the workspace.
|
||||
|
||||
Returns all available memory configurations associated with the authorized workspace.
|
||||
"""
|
||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.list_memory_configs(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
)
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
|
||||
@router.post("/end_users")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the authorized workspace.
|
||||
If an end user with the same other_id already exists, returns the existing one.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.create_end_user(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
other_id=payload.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {result['id']}")
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
|
||||
@@ -3,8 +3,11 @@ from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||
CustomToolTestRequest, ToolActiveUpdate
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
@@ -73,6 +76,8 @@ async def get_tool_methods(
|
||||
if methods is None:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
return success(data=methods, msg="获取工具方法成功")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -118,6 +123,8 @@ async def create_tool(
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -146,6 +153,8 @@ async def update_tool(
|
||||
return success(msg="工具更新成功")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -156,7 +165,7 @@ async def delete_tool(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除,is_active=False)"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
@@ -168,6 +177,32 @@ async def delete_tool(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||
async def set_tool_active(
|
||||
tool_id: str,
|
||||
request: ToolActiveUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""设置工具可用状态(启用/禁用)
|
||||
|
||||
- is_active=true: 启用工具
|
||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||
"""
|
||||
try:
|
||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
action = "启用" if request.is_active else "禁用"
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution/execute", response_model=ApiResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecuteRequest,
|
||||
@@ -196,6 +231,8 @@ async def execute_tool(
|
||||
},
|
||||
msg="工具执行完成"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -225,8 +262,10 @@ async def sync_mcp_tools(
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -249,8 +288,10 @@ async def test_tool_connection(
|
||||
# 普通连接测试
|
||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||
return success(data=result, msg="连接测试完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -42,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -50,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -57,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -74,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -105,7 +111,19 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -113,6 +131,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -125,8 +144,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
@@ -134,6 +152,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -144,7 +163,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -152,6 +171,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -164,7 +184,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -172,6 +192,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -186,16 +207,17 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
@@ -203,8 +225,8 @@ def verify_pwd(
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
@@ -212,6 +234,7 @@ async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
@@ -219,7 +242,7 @@ async def send_email_code(
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
@@ -227,6 +250,7 @@ async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
@@ -239,4 +263,51 @@ async def change_email(
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
|
||||
def get_current_user_language(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户的语言偏好"""
|
||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||
|
||||
language = user_service.get_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||
msg=t("users.language.get_success")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
|
||||
def update_current_user_language(
|
||||
request: user_schema.LanguagePreferenceRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""设置当前用户的语言偏好"""
|
||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||
|
||||
updated_user = user_service.update_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
language=request.language,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||
msg=t("users.language.update_success")
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -17,14 +17,17 @@ from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.end_user_info_schema import (
|
||||
EndUserInfoResponse,
|
||||
EndUserInfoCreate,
|
||||
EndUserInfoUpdate,
|
||||
)
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.dependencies import get_current_user
|
||||
@@ -44,9 +47,9 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
@@ -72,10 +75,10 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
@@ -89,7 +92,7 @@ async def get_user_summary_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -101,7 +104,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -116,10 +119,10 @@ async def get_user_summary_api(
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
@@ -133,7 +136,7 @@ async def generate_cache_api(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -154,10 +157,12 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||
language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -208,9 +213,9 @@ async def generate_cache_api(
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -219,7 +224,8 @@ async def get_node_statistics_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
api_logger.info(
|
||||
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
@@ -227,21 +233,23 @@ async def get_node_statistics_api(
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
api_logger.info(
|
||||
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
@@ -295,110 +303,165 @@ async def get_graph_data_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
other_name=end_user.other_name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||
|
||||
#=======================终端用户信息接口=======================
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
@router.get("/end_user_info", response_model=ApiResponse)
|
||||
async def get_end_user_info(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
查询终端用户信息记录
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
根据 end_user_id 查询单条终端用户信息记录。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 验证工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 调用 Service 层处理业务逻辑
|
||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="查询成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
||||
|
||||
|
||||
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
||||
async def update_end_user_info(
|
||||
info_update: EndUserInfoUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户信息记录
|
||||
|
||||
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
||||
|
||||
示例请求体:
|
||||
{
|
||||
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
"other_name": "张三1",
|
||||
"aliases": ["小张", "张工"],
|
||||
"meta_data": {"position": "工程师", "department": "技术部"}
|
||||
}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = info_update.end_user_id
|
||||
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
# 校验 end_user 是否属于当前工作空间
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||
if end_user is None:
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||
if str(end_user.workspace_id) != str(workspace_id):
|
||||
api_logger.warning(
|
||||
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||
)
|
||||
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||
|
||||
# 获取更新数据(排除 end_user_id)
|
||||
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
||||
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="更新成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
elif error_msg == "无效的用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||
if error_msg == "终端用户信息记录不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||
elif error_msg == "无效的终端用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||
else:
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
async def memory_space_timeline_of_shared_memories(
|
||||
id: str, label: str,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
@@ -410,11 +473,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
|
||||
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
try:
|
||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||
|
||||
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,8 +96,13 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@@ -98,6 +111,8 @@ def create_workspace(
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -118,8 +133,13 @@ def create_workspace(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -127,6 +147,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -139,14 +161,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -168,6 +203,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -178,7 +214,7 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@@ -187,6 +223,7 @@ def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
@@ -198,7 +235,7 @@ def delete_workspace_member(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -208,6 +245,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -220,7 +259,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -232,6 +276,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -246,18 +292,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -302,6 +367,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -312,7 +378,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -320,6 +386,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -354,14 +423,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
@@ -11,17 +11,14 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.errors import GraphRecursionError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -41,7 +38,10 @@ class LangChainAgent:
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
max_tool_consecutive_calls: int = 3, # 单个工具最大连续调用次数
|
||||
deep_thinking: bool = False, # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None, # 深度思考 token 预算
|
||||
capability: Optional[List[str]] = None # 模型能力列表,用于校验是否支持深度思考
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -64,6 +64,7 @@ class LangChainAgent:
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
self.deep_thinking = deep_thinking and ("thinking" in (capability or []))
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
@@ -86,6 +87,13 @@ class LangChainAgent:
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 根据 capability 校验是否真正支持深度思考
|
||||
actual_deep_thinking = self.deep_thinking
|
||||
if deep_thinking and not actual_deep_thinking:
|
||||
logger.warning(
|
||||
f"模型 {model_name} 不支持深度思考(capability 中无 'thinking'),已自动关闭 deep_thinking"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
@@ -93,10 +101,13 @@ class LangChainAgent:
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
deep_thinking=actual_deep_thinking,
|
||||
thinking_budget_tokens=thinking_budget_tokens if actual_deep_thinking else None,
|
||||
support_thinking="thinking" in (capability or []),
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"streaming": streaming # 使用参数控制流式
|
||||
"streaming": streaming
|
||||
}
|
||||
)
|
||||
|
||||
@@ -226,10 +237,9 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -254,6 +264,33 @@ class LangChainAgent:
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _extract_tokens_from_message(msg) -> int:
|
||||
"""从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式
|
||||
|
||||
支持的格式:
|
||||
- response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI)
|
||||
- response_metadata.usage.total_tokens (部分 provider)
|
||||
- usage_metadata.total_tokens (LangChain 新版)
|
||||
"""
|
||||
total = 0
|
||||
# 1. response_metadata
|
||||
response_meta = getattr(msg, "response_metadata", None)
|
||||
if response_meta and isinstance(response_meta, dict):
|
||||
# 尝试 token_usage 路径
|
||||
token_usage = response_meta.get("token_usage") or response_meta.get("usage", {})
|
||||
if isinstance(token_usage, dict):
|
||||
total = token_usage.get("total_tokens", 0)
|
||||
# 2. usage_metadata(LangChain 新版 AIMessage 属性)
|
||||
if not total:
|
||||
usage_meta = getattr(msg, "usage_metadata", None)
|
||||
if usage_meta:
|
||||
if isinstance(usage_meta, dict):
|
||||
total = usage_meta.get("total_tokens", 0)
|
||||
else:
|
||||
total = getattr(usage_meta, "total_tokens", 0)
|
||||
return total or 0
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -288,17 +325,23 @@ class LangChainAgent:
|
||||
|
||||
return content_parts
|
||||
|
||||
@staticmethod
|
||||
def _extract_reasoning_content(msg) -> str:
|
||||
"""从 AIMessage 中提取深度思考内容(reasoning_content)
|
||||
|
||||
所有 provider 统一通过 additional_kwargs.reasoning_content 传递:
|
||||
- DeepSeek-R1 / QwQ: 原生字段
|
||||
- Volcano (Doubao-thinking): 由 VolcanoChatOpenAI 从 delta.reasoning_content 注入
|
||||
"""
|
||||
additional = getattr(msg, "additional_kwargs", None) or {}
|
||||
return additional.get("reasoning_content") or additional.get("reasoning", "")
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -306,32 +349,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -355,7 +378,7 @@ class LangChainAgent:
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
except (RecursionError, GraphRecursionError) as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
@@ -378,6 +401,7 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
reasoning_content = ""
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
@@ -412,16 +436,13 @@ class LangChainAgent:
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
total_tokens = self._extract_tokens_from_message(msg)
|
||||
reasoning_content = self._extract_reasoning_content(msg) if self.deep_thinking else ""
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -432,6 +453,8 @@ class LangChainAgent:
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
if reasoning_content:
|
||||
response["reasoning_content"] = reasoning_content
|
||||
|
||||
logger.debug(
|
||||
"Agent 调用完成",
|
||||
@@ -452,22 +475,20 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int | dict[str, str], None]:
|
||||
"""执行流式对话
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
int: token 统计
|
||||
Dict: 深度思考内容 {"type": "reasoning", "content": "..."}
|
||||
"""
|
||||
logger.info("=" * 80)
|
||||
logger.info(" chat_stream 方法开始执行")
|
||||
@@ -475,23 +496,6 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -501,17 +505,19 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content = ''
|
||||
full_reasoning = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
@@ -520,12 +526,18 @@ class LangChainAgent:
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -536,29 +548,32 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content"):
|
||||
# 提取深度思考内容(仅在启用深度思考时)
|
||||
if self.deep_thinking:
|
||||
reasoning_chunk = self._extract_reasoning_content(chunk)
|
||||
if reasoning_chunk:
|
||||
full_reasoning += reasoning_chunk
|
||||
yield {"type": "reasoning", "content": reasoning_chunk}
|
||||
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -569,22 +584,18 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
@@ -594,17 +605,20 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
|
||||
except GraphRecursionError:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),模型可能不支持正确的工具调用停止判断"
|
||||
)
|
||||
if not full_content:
|
||||
yield "抱歉,我在处理您的请求时遇到了问题(已达最大处理步骤限制)。请尝试简化问题或更换模型后重试。"
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -97,6 +97,7 @@ class Settings:
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
@@ -162,6 +163,44 @@ class Settings:
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -192,8 +231,8 @@ class Settings:
|
||||
# Celery configuration (internal)
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ class BizCode(IntEnum):
|
||||
TENANT_NOT_FOUND = 3002
|
||||
WORKSPACE_NO_ACCESS = 3003
|
||||
WORKSPACE_INVITE_NOT_FOUND = 3004
|
||||
WORKSPACE_ACCESS_DENIED = 3005
|
||||
# API Key 管理(3xxx)
|
||||
API_KEY_NOT_FOUND = 3007
|
||||
API_KEY_DUPLICATE_NAME = 3008
|
||||
@@ -40,6 +41,7 @@ class BizCode(IntEnum):
|
||||
FILE_NOT_FOUND = 4006
|
||||
APP_NOT_FOUND = 4007
|
||||
RELEASE_NOT_FOUND = 4008
|
||||
USER_NO_ACCESS = 4009
|
||||
|
||||
# 冲突/状态(5xxx)
|
||||
DUPLICATE_NAME = 5001
|
||||
@@ -113,8 +115,11 @@ HTTP_MAPPING = {
|
||||
BizCode.FORBIDDEN: 403,
|
||||
BizCode.TENANT_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_NO_ACCESS: 403,
|
||||
BizCode.WORKSPACE_INVITE_NOT_FOUND: 400,
|
||||
BizCode.WORKSPACE_ACCESS_DENIED: 403,
|
||||
BizCode.NOT_FOUND: 400,
|
||||
BizCode.USER_NOT_FOUND: 200,
|
||||
BizCode.USER_NO_ACCESS: 401,
|
||||
BizCode.WORKSPACE_NOT_FOUND: 400,
|
||||
BizCode.MODEL_NOT_FOUND: 400,
|
||||
BizCode.KNOWLEDGE_NOT_FOUND: 400,
|
||||
|
||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
||||
# Fallback to console only if file write fails
|
||||
print(f"Warning: Could not write to timing log: {e}")
|
||||
|
||||
# Always print to console (backward compatible behavior)
|
||||
print(f"✓ {step_name}: {duration:.2f}s")
|
||||
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||
_timing_logger = logging.getLogger(__name__)
|
||||
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||
|
||||
|
||||
def get_agent_logger(name: str = "agent_service",
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Perceptual Memory Retrieval Node & Service
|
||||
|
||||
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||
with BM25+embedding fusion reranking.
|
||||
|
||||
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||
"""
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_perceptual,
|
||||
search_perceptual_by_embedding,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class PerceptualSearchService:
|
||||
"""
|
||||
感知记忆检索服务。
|
||||
|
||||
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
||||
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
||||
格式化并排序后的感知记忆列表和拼接文本。
|
||||
|
||||
Usage:
|
||||
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
||||
results = await service.search(query="...", keywords=[...], limit=10)
|
||||
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
||||
"""
|
||||
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
end_user_id: str,
|
||||
memory_config: Any,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
||||
):
|
||||
self.end_user_id = end_user_id
|
||||
self.memory_config = memory_config
|
||||
self.alpha = alpha
|
||||
self.content_score_threshold = content_score_threshold
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
keywords: Optional[List[str]] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
||||
|
||||
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
||||
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
||||
|
||||
Args:
|
||||
query: 原始用户查询(用于向量检索和 BM25 补查)
|
||||
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
||||
limit: 最大返回数量
|
||||
|
||||
Returns:
|
||||
{
|
||||
"memories": [格式化后的记忆 dict, ...],
|
||||
"content": "拼接的纯文本摘要",
|
||||
"keyword_raw": int,
|
||||
"embedding_raw": int,
|
||||
}
|
||||
"""
|
||||
if keywords is None:
|
||||
keywords = [query] if query else []
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
kw_task = self._keyword_search(connector, keywords, limit)
|
||||
emb_task = self._embedding_search(connector, query, limit)
|
||||
|
||||
kw_results, emb_results = await asyncio.gather(
|
||||
kw_task, emb_task, return_exceptions=True
|
||||
)
|
||||
if isinstance(kw_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||
kw_results = []
|
||||
if isinstance(emb_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||
emb_results = []
|
||||
|
||||
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
||||
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
||||
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
||||
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
||||
|
||||
if emb_only_ids and query:
|
||||
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
||||
# 把补查到的 BM25 score 注入到 embedding 结果中
|
||||
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
||||
for r in emb_results:
|
||||
rid = r.get("id", "")
|
||||
if rid in backfill_map:
|
||||
r["bm25_backfill_score"] = backfill_map[rid]
|
||||
logger.info(
|
||||
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
||||
f"{len(backfill_map)} got BM25 scores"
|
||||
)
|
||||
|
||||
reranked = self._rerank(kw_results, emb_results, limit)
|
||||
|
||||
memories = []
|
||||
content_parts = []
|
||||
for record in reranked:
|
||||
fmt = self._format_result(record)
|
||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||
memories.append(fmt)
|
||||
content_parts.append(self._build_content_text(fmt))
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||
)
|
||||
return {
|
||||
"memories": memories,
|
||||
"content": "\n\n".join(content_parts),
|
||||
"keyword_raw": len(kw_results),
|
||||
"embedding_raw": len(emb_results),
|
||||
}
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
async def _bm25_backfill(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query: str,
|
||||
target_ids: set,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
对指定 id 集合补查全文索引 BM25 score。
|
||||
|
||||
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
||||
"""
|
||||
escaped = escape_lucene_query(query)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
try:
|
||||
r = await search_perceptual(
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id,
|
||||
limit=limit * 5, # 多查一些以提高命中率
|
||||
)
|
||||
all_hits = r.get("perceptuals", [])
|
||||
return [h for h in all_hits if h.get("id") in target_ids]
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
||||
return []
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
keywords: List[str],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
||||
seen_ids: set = set()
|
||||
all_results: List[dict] = []
|
||||
|
||||
async def _one(kw: str):
|
||||
escaped = escape_lucene_query(kw)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=connector, query=escaped,
|
||||
end_user_id=self.end_user_id, limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
|
||||
tasks = [_one(kw) for kw in keywords[:10]]
|
||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||
continue
|
||||
for rec in result:
|
||||
rid = rec.get("id", "")
|
||||
if rid and rid not in seen_ids:
|
||||
seen_ids.add(rid)
|
||||
all_results.append(rec)
|
||||
|
||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||
return all_results[:limit]
|
||||
|
||||
async def _embedding_search(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
with get_db_context() as db:
|
||||
cfg = MemoryConfigService(db).get_embedder_config(
|
||||
str(self.memory_config.embedding_model_id)
|
||||
)
|
||||
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
||||
|
||||
r = await search_perceptual_by_embedding(
|
||||
connector=connector, embedder_client=client,
|
||||
query_text=query_text, end_user_id=self.end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
return r.get("perceptuals", [])
|
||||
except Exception as e:
|
||||
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
||||
return []
|
||||
|
||||
def _rerank(
|
||||
self,
|
||||
keyword_results: List[dict],
|
||||
embedding_results: List[dict],
|
||||
limit: int,
|
||||
) -> List[dict]:
|
||||
"""BM25 + embedding 融合排序。
|
||||
|
||||
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
||||
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
||||
"""
|
||||
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
||||
emb_backfill_items = []
|
||||
for item in embedding_results:
|
||||
backfill_score = item.get("bm25_backfill_score")
|
||||
if backfill_score is not None and item.get("id"):
|
||||
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
||||
|
||||
# 合并后统一归一化 BM25 scores
|
||||
all_bm25_items = keyword_results + emb_backfill_items
|
||||
all_bm25_items = self._normalize_scores(all_bm25_items)
|
||||
|
||||
# 建立 id -> normalized BM25 score 的映射
|
||||
bm25_norm_map: Dict[str, float] = {}
|
||||
for item in all_bm25_items:
|
||||
item_id = item.get("id", "")
|
||||
if item_id:
|
||||
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
# 归一化 embedding scores
|
||||
embedding_results = self._normalize_scores(embedding_results)
|
||||
|
||||
# 合并
|
||||
combined: Dict[str, dict] = {}
|
||||
for item in keyword_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = 0.0
|
||||
|
||||
for item in embedding_results:
|
||||
item_id = item.get("id", "")
|
||||
if not item_id:
|
||||
continue
|
||||
if item_id in combined:
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
for item in combined.values():
|
||||
bm25 = float(item.get("bm25_score", 0) or 0)
|
||||
emb = float(item.get("embedding_score", 0) or 0)
|
||||
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
||||
|
||||
results = list(combined.values())
|
||||
before = len(results)
|
||||
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
||||
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
||||
"""Z-score + sigmoid 归一化。"""
|
||||
if not items:
|
||||
return items
|
||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||
if len(scores) <= 1:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
return items
|
||||
mean = sum(scores) / len(scores)
|
||||
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
||||
std = math.sqrt(var)
|
||||
if std == 0:
|
||||
for it in items:
|
||||
it[f"normalized_{field}"] = 1.0
|
||||
else:
|
||||
for it, s in zip(items, scores):
|
||||
z = (s - mean) / std
|
||||
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _format_result(record: dict) -> dict:
|
||||
return {
|
||||
"id": record.get("id", ""),
|
||||
"perceptual_type": record.get("perceptual_type", ""),
|
||||
"file_name": record.get("file_name", ""),
|
||||
"file_path": record.get("file_path", ""),
|
||||
"summary": record.get("summary", ""),
|
||||
"topic": record.get("topic", ""),
|
||||
"domain": record.get("domain", ""),
|
||||
"keywords": record.get("keywords", []),
|
||||
"created_at": str(record.get("created_at", "")),
|
||||
"file_type": record.get("file_type", ""),
|
||||
"score": record.get("score", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_content_text(formatted: dict) -> str:
|
||||
parts = []
|
||||
if formatted["summary"]:
|
||||
parts.append(formatted["summary"])
|
||||
if formatted["topic"]:
|
||||
parts.append(f"[主题: {formatted['topic']}]")
|
||||
if formatted["keywords"]:
|
||||
kw_list = formatted["keywords"]
|
||||
if isinstance(kw_list, list):
|
||||
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
||||
if formatted["file_name"]:
|
||||
parts.append(f"[文件: {formatted['file_name']}]")
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||
"""Extract search keywords from problem extension results."""
|
||||
keywords = []
|
||||
context = problem_extension.get("context", {})
|
||||
if isinstance(context, dict):
|
||||
for original_q, extended_qs in context.items():
|
||||
keywords.append(original_q)
|
||||
if isinstance(extended_qs, list):
|
||||
keywords.extend(extended_qs)
|
||||
return keywords
|
||||
|
||||
|
||||
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
||||
"""
|
||||
LangGraph node: perceptual memory retrieval.
|
||||
|
||||
Uses PerceptualSearchService to run keyword + embedding search with
|
||||
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", "")
|
||||
problem_extension = state.get("problem_extension", {})
|
||||
original_query = state.get("data", "")
|
||||
memory_config = state.get("memory_config", None)
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
||||
|
||||
keywords = _extract_keywords_from_problems(problem_extension)
|
||||
if not keywords:
|
||||
keywords = [original_query] if original_query else []
|
||||
|
||||
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
||||
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
search_result = await service.search(
|
||||
query=original_query,
|
||||
keywords=keywords,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
result = {
|
||||
"memories": search_result["memories"],
|
||||
"content": search_result["content"],
|
||||
"_intermediate": {
|
||||
"type": "perceptual_retrieve",
|
||||
"title": "感知记忆检索",
|
||||
"data": search_result["memories"],
|
||||
"query": original_query,
|
||||
"result_count": len(search_result["memories"]),
|
||||
},
|
||||
}
|
||||
return {"perceptual_data": result}
|
||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
"""
|
||||
Problem processing node service class
|
||||
|
||||
Handles problem decomposition and extension operations using LLM services.
|
||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including spit_context field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
@@ -231,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
|
||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -128,11 +169,18 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||
text_parts = []
|
||||
seen_community_names = set()
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
# community 节点用 name 去重
|
||||
if 'member_count' in item or 'core_entities' in item:
|
||||
community_name = item.get('name') or item.get('id', '')
|
||||
if community_name in seen_community_names:
|
||||
continue
|
||||
seen_community_names.add(community_name)
|
||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
@@ -146,10 +194,19 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
@@ -163,7 +220,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -209,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -257,7 +314,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
@@ -291,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||
}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
@@ -299,21 +373,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define semaphore at module level to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -327,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||
_expanded_stmts_to_write = []
|
||||
try:
|
||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||
reranked = results_dict.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
if not community_hits:
|
||||
community_hits = results_dict.get('communities', [])
|
||||
if community_hits:
|
||||
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_hits,
|
||||
end_user_id=end_user_id,
|
||||
existing_content=clean_content,
|
||||
)
|
||||
if new_texts:
|
||||
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
@@ -362,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
PerceptualSearchService,
|
||||
)
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
@@ -23,18 +27,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -54,6 +79,23 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -74,6 +116,18 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
@@ -82,11 +136,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -103,7 +172,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
@@ -112,23 +181,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
@@ -137,7 +206,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -148,9 +217,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
@@ -165,6 +234,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -179,6 +261,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
@@ -217,6 +313,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -229,13 +338,56 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
||||
}
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
|
||||
async def _perceptual_search():
|
||||
service = PerceptualSearchService(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
return await service.search(query=data, limit=5)
|
||||
|
||||
hybrid_task = SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False,
|
||||
)
|
||||
perceptual_task = _perceptual_search()
|
||||
|
||||
gather_results = await asyncio.gather(
|
||||
hybrid_task, perceptual_task, return_exceptions=True
|
||||
)
|
||||
hybrid_result = gather_results[0]
|
||||
perceptual_results = gather_results[1]
|
||||
|
||||
# 处理 hybrid search 异常
|
||||
if isinstance(hybrid_result, Exception):
|
||||
raise hybrid_result
|
||||
retrieve_info, question, raw_results = hybrid_result
|
||||
|
||||
# 处理感知记忆结果
|
||||
if isinstance(perceptual_results, Exception):
|
||||
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||
perceptual_results = []
|
||||
|
||||
# 拼接感知记忆内容到 retrieve_info
|
||||
if perceptual_results and isinstance(perceptual_results, dict):
|
||||
perceptual_content = perceptual_results.get("content", "")
|
||||
if perceptual_content:
|
||||
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||
count = len(perceptual_results.get("memories", []))
|
||||
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
@@ -257,15 +409,25 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
duration = end - start
|
||||
log_time('检索', duration)
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
@@ -285,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
aimessages = await summary_llm(
|
||||
state,
|
||||
history,
|
||||
retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2',
|
||||
'retrieve_summary', RetrieveSummaryResponse,
|
||||
"1"
|
||||
)
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -299,13 +473,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -318,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
@@ -336,13 +529,26 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
@@ -355,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
|
||||
# Merge perceptual memory content
|
||||
perceptual_data = state.get("perceptual_data", {})
|
||||
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||
if perceptual_content:
|
||||
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
|
||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -15,7 +15,10 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
retrieve_nodes,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||
perceptual_retrieve_node,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
@@ -33,147 +36,54 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Retrieve", retrieve_nodes)
|
||||
# workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
logger.error(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
print(f"总耗时: {end - start}s")
|
||||
print(100 * 'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,77 +2,84 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
@@ -80,29 +87,42 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
actual_end_user_id, # end_user_id: User ID
|
||||
structured_messages, # message: JSON string format message list
|
||||
str(actual_config_id), # config_id: Configuration ID string
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for memory association
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -112,73 +132,90 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
write_message_task.delay(
|
||||
end_user_id, # end_user_id: User ID
|
||||
redis_messages, # message: JSON string format message list
|
||||
config_id, # config_id: Configuration ID string
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
history = None
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
@@ -210,7 +247,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -223,16 +260,14 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
"""
|
||||
Input schema for time retrieval tool
|
||||
|
||||
Defines the expected input parameters for time-based retrieval operations.
|
||||
Used for validation and documentation of tool parameters.
|
||||
|
||||
Attributes:
|
||||
context: User input query content for search
|
||||
end_user_id: Group ID for filtering search results, defaults to test user
|
||||
"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,51 +216,61 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
data: Data to be cleaned (can be dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
# Clean each element in list
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -216,57 +280,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -275,38 +345,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -315,7 +399,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
Format and parse message lists into different output types
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
messages: List of message objects from storage containing message data
|
||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
list: Formatted message list in the specified format
|
||||
- 'string': List of formatted text messages with role prefixes
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -1,104 +1,94 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(
|
||||
storage_type: str,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
user_rag_memory_id: str,
|
||||
actual_config_id: str
|
||||
):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
messages: message list
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
message_content = []
|
||||
for message in messages:
|
||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||
messages_string = "\n".join(message_content)
|
||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
await long_term_storage(long_term_type=CHUNK,
|
||||
langchain_messages=messages,
|
||||
memory_config_id=actual_config_id,
|
||||
end_user_id=end_user_id,
|
||||
scope=SCOPE)
|
||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
@@ -10,18 +10,84 @@ from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
_EXPAND_FIELDS_TO_REMOVE = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
|
||||
|
||||
def _clean_expand_fields(obj):
|
||||
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||
if isinstance(obj, list):
|
||||
return [_clean_expand_fields(i) for i in obj]
|
||||
return obj
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
|
||||
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||
- 过滤不可序列化字段
|
||||
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||
"""
|
||||
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||
if not community_ids or not end_user_id:
|
||||
return [], []
|
||||
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
result = await search_graph_community_expand(
|
||||
connector=connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||
return [], []
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
expanded_stmts = result.get("expanded_statements", [])
|
||||
if not expanded_stmts:
|
||||
return [], []
|
||||
|
||||
existing_lines = set(existing_content.splitlines())
|
||||
new_texts = [
|
||||
s["statement"] for s in expanded_stmts
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(
|
||||
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
@@ -30,35 +96,50 @@ class SearchService:
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
node_type: Hint for node type ("community", "summary", etc.)
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
"""
|
||||
if not isinstance(result, dict):
|
||||
return str(result)
|
||||
|
||||
|
||||
content_parts = []
|
||||
|
||||
|
||||
# Statements: extract statement field
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
prefix = f"[主题:{name}] " if name else ""
|
||||
content_parts.append(f"{prefix}{content}")
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
# if 'name' in result and result['name']:
|
||||
# content_parts.append(result['name'])
|
||||
# if result.get('fact_summary'):
|
||||
# content_parts.append(result['fact_summary'])
|
||||
|
||||
|
||||
# Return concatenated content or empty string
|
||||
return '\n'.join(content_parts) if content_parts else ""
|
||||
|
||||
|
||||
def clean_query(self, query: str) -> str:
|
||||
"""
|
||||
Clean and escape query text for Lucene.
|
||||
@@ -74,32 +155,33 @@ class SearchService:
|
||||
Cleaned and escaped query string
|
||||
"""
|
||||
q = str(query).strip()
|
||||
|
||||
|
||||
# Remove wrapping quotes
|
||||
if (q.startswith("'") and q.endswith("'")) or (
|
||||
q.startswith('"') and q.endswith('"')
|
||||
q.startswith('"') and q.endswith('"')
|
||||
):
|
||||
q = q[1:-1]
|
||||
|
||||
|
||||
# Remove newlines and carriage returns
|
||||
q = q.replace('\r', ' ').replace('\n', ' ').strip()
|
||||
|
||||
|
||||
# Apply Lucene escaping
|
||||
q = escape_lucene_query(q)
|
||||
|
||||
|
||||
return q
|
||||
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None
|
||||
self,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config=None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -114,17 +196,19 @@ class SearchService:
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
expand_communities: If True, expand community hits to member statements (default: True).
|
||||
Set to False for quick-summary paths that only need community-level text.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
@@ -137,18 +221,18 @@ class SearchService:
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
# Prioritize summaries as they contain synthesized contextual information
|
||||
answer_list = []
|
||||
|
||||
|
||||
# For hybrid search, use reranked_results
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
category_results = reranked_results[category]
|
||||
@@ -157,33 +241,46 @@ class SearchService:
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||
if expand_communities and "communities" in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get('communities', [])
|
||||
if search_type == "hybrid"
|
||||
else answer.get('communities', [])
|
||||
)
|
||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_results,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
# community 节点有 member_count 或 core_entities 字段
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
clean_content = '\n'.join([c for c in content_list if c])
|
||||
|
||||
|
||||
# Log first 200 chars
|
||||
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
|
||||
|
||||
|
||||
# Return raw results if requested
|
||||
if return_raw_results:
|
||||
return clean_content, cleaned_query, answer
|
||||
else:
|
||||
return clean_content, cleaned_query, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
|
||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
ref_id: str = "",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
files = msg.get("file_content", [])
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
@@ -84,7 +85,7 @@ async def get_chunked_dialogs(
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_classes=memory_config.ontology_classes,
|
||||
ontology_class_infos=memory_config.ontology_class_infos,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
@@ -8,10 +7,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +20,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +44,21 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
perceptual_data: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -39,6 +39,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
||||
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
@@ -71,6 +95,34 @@
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "用户是谁?",
|
||||
"extended_question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "我推荐的书是什么?",
|
||||
"extended_question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
|
||||
@@ -27,6 +27,30 @@
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
## 指代消歧规则(Coreference Resolution):
|
||||
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||
|
||||
1. **"用户"的消歧**:
|
||||
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||
|
||||
2. **"我"的消歧**:
|
||||
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||
|
||||
3. **"他/她/它"的消歧**:
|
||||
- 从上下文或历史中找出最近提到的同类实体
|
||||
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||
|
||||
4. **"那个人/这个人"的消歧**:
|
||||
- 从历史中找出最近提到的人物
|
||||
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||
|
||||
5. **优先级**:
|
||||
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||
|
||||
## 指令:
|
||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||
单跳(Single-hop)
|
||||
@@ -151,6 +175,34 @@
|
||||
]
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
|
||||
## 指代消歧示例(重要):
|
||||
示例1 - "用户"的消歧:
|
||||
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||
输入问题:"用户是谁?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "李建国是谁?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||
}
|
||||
]
|
||||
|
||||
示例2 - "我"的消歧:
|
||||
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||
输入问题:"我推荐的书是什么?"
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"id": "Q1",
|
||||
"question": "张曼玉推荐的书是什么?",
|
||||
"type": "单跳",
|
||||
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||
}
|
||||
]
|
||||
|
||||
- 关键的JSON格式要求
|
||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||
|
||||
@@ -3,8 +3,9 @@ import uuid
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
@@ -6,14 +6,18 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import _USER_PLACEHOLDER_NAMES
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
@@ -23,18 +27,17 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
@@ -43,9 +46,11 @@ async def write(
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
ref_id: Reference ID, defaults to ""
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
if not ref_id:
|
||||
ref_id = uuid.uuid4().hex
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
@@ -99,14 +104,14 @@ async def write(
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
@@ -135,9 +140,11 @@ async def write(
|
||||
all_chunk_nodes,
|
||||
all_statement_nodes,
|
||||
all_entity_nodes,
|
||||
all_perceptual_nodes,
|
||||
all_statement_chunk_edges,
|
||||
all_statement_entity_edges,
|
||||
all_entity_entity_edges,
|
||||
all_perceptual_edges,
|
||||
all_dedup_details,
|
||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||
|
||||
@@ -145,11 +152,24 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
|
||||
# Neo4j 写入前:清洗用户/AI助手实体之间的别名交叉污染
|
||||
# 从 Neo4j 查询已有的 AI 助手别名,与本轮实体中的 AI 助手别名合并,
|
||||
# 确保用户实体的 aliases 不包含 AI 助手的名字
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
clean_cross_role_aliases,
|
||||
fetch_neo4j_assistant_aliases,
|
||||
)
|
||||
neo4j_assistant_aliases = set()
|
||||
if all_entity_nodes:
|
||||
_eu_id = all_entity_nodes[0].end_user_id
|
||||
if _eu_id:
|
||||
neo4j_assistant_aliases = await fetch_neo4j_assistant_aliases(neo4j_connector, _eu_id)
|
||||
clean_cross_role_aliases(all_entity_nodes, external_assistant_aliases=neo4j_assistant_aliases)
|
||||
logger.info(f"Neo4j 写入前别名清洗完成,AI助手别名排除集大小: {len(neo4j_assistant_aliases)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
logger.warning(f"Neo4j 写入前别名清洗失败(不影响主流程): {e}")
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
@@ -162,13 +182,63 @@ async def write(
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
perceptual_nodes=all_perceptual_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
perceptual_edges=all_perceptual_edges,
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
if all_entity_nodes:
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
|
||||
# Neo4j 写入完成后,用 PgSQL 权威 aliases 覆盖 Neo4j 用户实体
|
||||
try:
|
||||
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||
if end_user_id:
|
||||
with get_db_context() as db_session:
|
||||
info = EndUserInfoRepository(db_session).get_by_end_user_id(uuid.UUID(end_user_id))
|
||||
pg_aliases = info.aliases if info and info.aliases else []
|
||||
if info is not None:
|
||||
# 将 Python 侧占位名集合作为参数传入,避免 Cypher 硬编码
|
||||
placeholder_names = list(_USER_PLACEHOLDER_NAMES)
|
||||
await neo4j_connector.execute_query(
|
||||
"""
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $placeholder_names
|
||||
SET e.aliases = $aliases
|
||||
""",
|
||||
end_user_id=end_user_id, aliases=pg_aliases,
|
||||
placeholder_names=placeholder_names,
|
||||
)
|
||||
logger.info(f"[AliasSync] Neo4j 用户实体 aliases 已用 PgSQL 权威源覆盖: {pg_aliases}")
|
||||
except Exception as sync_err:
|
||||
logger.warning(f"[AliasSync] PgSQL→Neo4j aliases 同步失败(不影响主流程): {sync_err}")
|
||||
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
"new_entity_ids": new_entity_ids,
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
@@ -202,9 +272,8 @@ async def write(
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
ms_connector = Neo4jConnector()
|
||||
try:
|
||||
ms_connector = Neo4jConnector()
|
||||
await add_memory_summary_nodes(summaries, ms_connector)
|
||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||
finally:
|
||||
@@ -244,5 +313,21 @@ async def write(
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any, List
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
files=msg.files
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
OpenAI Embedder 客户端实现
|
||||
|
||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||
自动支持火山引擎的多模态 Embedding。
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
- 批量文本嵌入
|
||||
- 自动重试机制
|
||||
- 错误处理
|
||||
- 火山引擎多模态 Embedding(自动识别)
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: RedBearModelConfig):
|
||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
"""
|
||||
super().__init__(model_config)
|
||||
|
||||
# 初始化 RedBearEmbeddings 模型
|
||||
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
self.model = RedBearEmbeddings(
|
||||
RedBearModelConfig(
|
||||
model_name=self.model_name,
|
||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
)
|
||||
self.is_multimodal = self.model.is_multimodal_supported()
|
||||
|
||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
||||
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||
|
||||
async def response(
|
||||
self,
|
||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
||||
return []
|
||||
|
||||
# 生成嵌入向量
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
if self.is_multimodal:
|
||||
# 火山引擎多模态 Embedding
|
||||
embeddings = await self.model.aembed_multimodal(
|
||||
[{"type": "text", "text": text} for text in texts]
|
||||
)
|
||||
else:
|
||||
# 普通 Embedding
|
||||
embeddings = await self.model.aembed_documents(texts)
|
||||
|
||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
@@ -58,6 +58,14 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# User metadata models
|
||||
from app.core.memory.models.metadata_models import (
|
||||
UserMetadata,
|
||||
UserMetadataBehavioralHints,
|
||||
UserMetadataProfile,
|
||||
MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
@@ -124,6 +132,10 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
"UserMetadata",
|
||||
"UserMetadataBehavioralHints",
|
||||
"UserMetadataProfile",
|
||||
"MetadataExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
|
||||
@@ -6,6 +6,7 @@ of the memory system including LLM, chunking, pruning, and search.
|
||||
Classes:
|
||||
LLMConfig: Configuration for LLM client
|
||||
ChunkerConfig: Configuration for dialogue chunking
|
||||
OntologyClassInfo: Single ontology class with name and description
|
||||
PruningConfig: Configuration for semantic pruning
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
@@ -50,30 +51,41 @@ class ChunkerConfig(BaseModel):
|
||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||
|
||||
|
||||
class OntologyClassInfo(BaseModel):
|
||||
"""本体类型的名称与语义描述,用于剪枝提示词注入。
|
||||
|
||||
Attributes:
|
||||
class_name: 本体类型名称(如"患者"、"课程")
|
||||
class_description: 本体类型语义描述,告知 LLM 该类型在当前场景下的含义
|
||||
"""
|
||||
class_name: str = Field(..., description="本体类型名称")
|
||||
class_description: str = Field(default="", description="本体类型语义描述")
|
||||
|
||||
|
||||
class PruningConfig(BaseModel):
|
||||
"""Configuration for semantic pruning of dialogue content.
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene name for pruning, either a built-in key
|
||||
('education', 'online_service', 'outbound') or a custom scene_name
|
||||
from ontology_scene table
|
||||
pruning_scene: Scene name for pruning from ontology_scene table
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
||||
ontology_classes: List of class_name strings from ontology_class table,
|
||||
injected into the prompt when pruning_scene is not a built-in scene
|
||||
scene_id: Optional ontology scene UUID
|
||||
ontology_class_infos: Full ontology class info (name + description) from
|
||||
ontology_class table, injected into the pruning prompt to drive
|
||||
scene-aware preservation decisions
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
||||
description="Scene name from ontology_scene table.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||
ontology_classes: Optional[List[str]] = Field(
|
||||
None, description="Class names from ontology_class table for custom scenes."
|
||||
ontology_class_infos: List[OntologyClassInfo] = Field(
|
||||
default_factory=list,
|
||||
description="Full ontology class info (name + description) injected into pruning prompt."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
|
||||
"""
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
|
||||
if isinstance(v, str):
|
||||
# 匹配 ISO 8601 格式:YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
|
||||
# 支持1-4位年份
|
||||
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
|
||||
match = re.match(pattern, v)
|
||||
|
||||
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
|
||||
minute = int(match.group(5)) if match.group(5) else 0
|
||||
second = int(match.group(6)) if match.group(6) else 0
|
||||
microsecond = 0
|
||||
|
||||
|
||||
# 处理微秒
|
||||
if match.group(7):
|
||||
# 补齐或截断到6位
|
||||
us_str = match.group(7).ljust(6, '0')[:6]
|
||||
microsecond = int(us_str)
|
||||
|
||||
|
||||
# 处理时区
|
||||
tzinfo = None
|
||||
if 'Z' in v or match.group(8):
|
||||
tzinfo = timezone.utc
|
||||
|
||||
|
||||
# 创建 datetime 对象
|
||||
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
|
||||
|
||||
|
||||
except (ValueError, OverflowError):
|
||||
# 日期值无效(如月份13、日期32等)
|
||||
return None
|
||||
|
||||
|
||||
# 如果不匹配模式,尝试使用 fromisoformat(用于标准格式)
|
||||
try:
|
||||
return datetime.fromisoformat(v.replace('Z', '+00:00'))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
return v
|
||||
|
||||
|
||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||
|
||||
|
||||
class ChunkEdge(Edge):
|
||||
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
|
||||
source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
class PerceptualEdge(Edge):
|
||||
"""Edge connecting perceptual nodes to their source chunks
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Node(BaseModel):
|
||||
"""Base class for all graph nodes in the knowledge graph.
|
||||
|
||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||
content: str = Field(..., description="Dialogue content")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this dialogue (integer or string)")
|
||||
|
||||
|
||||
class StatementNode(Node):
|
||||
@@ -241,17 +248,17 @@ class StatementNode(Node):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Emotion intensity: 0.0-1.0 (displayed on node)"
|
||||
)
|
||||
@@ -264,25 +271,26 @@ class StatementNode(Node):
|
||||
description="Emotion subject: self/other/object"
|
||||
)
|
||||
emotion_type: Optional[str] = Field(
|
||||
None,
|
||||
None,
|
||||
description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
|
||||
)
|
||||
emotion_keywords: Optional[List[str]] = Field(
|
||||
default_factory=list,
|
||||
description="Emotion keywords list, max 3 items"
|
||||
)
|
||||
|
||||
|
||||
# Temporal fields
|
||||
temporal_info: TemporalInfo = Field(..., description="Temporal information")
|
||||
valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
|
||||
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
|
||||
|
||||
|
||||
# Embedding and other fields
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -309,13 +317,13 @@ class StatementNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
"""使用通用的历史日期解析函数"""
|
||||
return parse_historical_datetime(v)
|
||||
|
||||
|
||||
@field_validator('emotion_type', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_type(cls, v):
|
||||
@@ -326,7 +334,7 @@ class StatementNode(Node):
|
||||
if v not in valid_types:
|
||||
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_subject', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_subject(cls, v):
|
||||
@@ -337,7 +345,7 @@ class StatementNode(Node):
|
||||
if v not in valid_subjects:
|
||||
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
|
||||
return v
|
||||
|
||||
|
||||
@field_validator('emotion_keywords', mode='before')
|
||||
@classmethod
|
||||
def validate_emotion_keywords(cls, v):
|
||||
@@ -356,12 +364,14 @@ class ChunkNode(Node):
|
||||
Attributes:
|
||||
dialog_id: ID of the parent dialog
|
||||
content: The text content of the chunk
|
||||
speaker: Speaker identifier ('user' or 'assistant')
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
sequence_number: Order of this chunk within the dialog
|
||||
metadata: Additional chunk metadata as key-value pairs
|
||||
"""
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
content: str = Field(..., description="The text content of the chunk")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||
sequence_number: int = Field(..., description="Order of this chunk within the dialog")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional chunk metadata")
|
||||
@@ -405,19 +415,20 @@ class ExtractedEntityNode(Node):
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -444,16 +455,16 @@ class ExtractedEntityNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
"""Validate and clean aliases field using utility function.
|
||||
|
||||
This validator ensures that the aliases field is always a valid list of strings.
|
||||
@@ -507,8 +518,9 @@ class MemorySummaryNode(Node):
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
config_id: Optional[int | str] = Field(None,
|
||||
description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
@@ -522,7 +534,7 @@ class MemorySummaryNode(Node):
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
@@ -549,3 +561,18 @@ class MemorySummaryNode(Node):
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
|
||||
class PerceptualNode(Node):
|
||||
"""Node representing a multimodal message in the knowledge graph.
|
||||
"""
|
||||
perceptual_type: int
|
||||
file_path: str
|
||||
file_name: str
|
||||
file_ext: str
|
||||
summary: str
|
||||
keywords: list[str]
|
||||
topic: str
|
||||
domain: str
|
||||
file_type: str
|
||||
summary_embedding: list[float] | None
|
||||
|
||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||
|
||||
|
||||
class TemporalValidityRange(BaseModel):
|
||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
|
||||
57
api/app/core/memory/models/metadata_models.py
Normal file
57
api/app/core/memory/models/metadata_models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Models for user metadata extraction.
|
||||
|
||||
Independent from triplet_models.py - these models are used by the
|
||||
standalone metadata extraction pipeline (post-dedup async Celery task).
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserMetadataProfile(BaseModel):
|
||||
"""用户画像信息"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
role: str = Field(default="", description="用户职业或角色")
|
||||
domain: str = Field(default="", description="用户所在领域")
|
||||
expertise: List[str] = Field(
|
||||
default_factory=list, description="用户擅长的技能或工具"
|
||||
)
|
||||
interests: List[str] = Field(
|
||||
default_factory=list, description="用户关注的话题或领域标签"
|
||||
)
|
||||
|
||||
|
||||
class UserMetadataBehavioralHints(BaseModel):
|
||||
"""行为偏好"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
learning_stage: str = Field(default="", description="学习阶段")
|
||||
preferred_depth: str = Field(default="", description="偏好深度")
|
||||
tone_preference: str = Field(default="", description="语气偏好")
|
||||
|
||||
|
||||
class UserMetadata(BaseModel):
|
||||
"""用户元数据顶层结构"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
profile: UserMetadataProfile = Field(default_factory=UserMetadataProfile)
|
||||
behavioral_hints: UserMetadataBehavioralHints = Field(
|
||||
default_factory=UserMetadataBehavioralHints
|
||||
)
|
||||
knowledge_tags: List[str] = Field(default_factory=list, description="知识标签")
|
||||
|
||||
|
||||
class MetadataExtractionResponse(BaseModel):
|
||||
"""元数据提取 LLM 响应结构"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
user_metadata: UserMetadata = Field(default_factory=UserMetadata)
|
||||
aliases_to_add: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="本次新发现的用户别名(用户自我介绍或他人对用户的称呼)",
|
||||
)
|
||||
aliases_to_remove: List[str] = Field(
|
||||
default_factory=list, description="用户明确否认的别名(如'我不叫XX了')"
|
||||
)
|
||||
@@ -1,4 +1,3 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
@@ -6,7 +5,6 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -23,7 +21,7 @@ from app.core.memory.utils.config.config_utils import (
|
||||
)
|
||||
from app.core.memory.utils.data.text_utils import extract_plain_query
|
||||
from app.core.memory.utils.data.time_utils import normalize_date_safe
|
||||
from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
# from app.core.memory.utils.llm.llm_utils import get_reranker_client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
@@ -43,6 +41,7 @@ load_dotenv()
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||
if value is None:
|
||||
@@ -75,7 +74,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -83,10 +82,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
for item in results:
|
||||
@@ -94,7 +93,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
item[f"normalized_{score_field}"] = None
|
||||
return results
|
||||
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
@@ -132,7 +131,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
@@ -150,52 +148,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
content_score_threshold: float = 0.5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
@@ -222,6 +221,8 @@ def rerank_with_activation(
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score),
|
||||
低于此阈值的结果会被过滤。默认 0.5。
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
@@ -229,26 +230,26 @@ def rerank_with_activation(
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -257,7 +258,7 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
@@ -271,62 +272,64 @@ def rerank_with_activation(
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
@@ -336,7 +339,7 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
@@ -345,11 +348,11 @@ def rerank_with_activation(
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
@@ -357,14 +360,14 @@ def rerank_with_activation(
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
@@ -372,7 +375,7 @@ def rerank_with_activation(
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
@@ -388,16 +391,29 @@ def rerank_with_activation(
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
|
||||
if content_score_threshold > 0:
|
||||
before_count = len(sorted_items)
|
||||
sorted_items = [
|
||||
item for item in sorted_items
|
||||
if float(item.get("content_score", 0) or 0) >= content_score_threshold
|
||||
]
|
||||
filtered_count = before_count - len(sorted_items)
|
||||
if filtered_count > 0:
|
||||
logger.info(
|
||||
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
|
||||
f"items below content_score_threshold={content_score_threshold}"
|
||||
)
|
||||
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
|
||||
log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
@@ -410,7 +426,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
|
||||
"""
|
||||
# Ensure the query text is plain and clean before logging
|
||||
cleaned_query = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
@@ -437,8 +453,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
||||
|
||||
|
||||
def apply_reranker_placeholder(
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
query_text: str,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Placeholder for a cross-encoder reranker.
|
||||
@@ -481,7 +497,7 @@ def apply_reranker_placeholder(
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Apply LLM-based reranking to search results.
|
||||
|
||||
|
||||
# Args:
|
||||
# results: Search results organized by category
|
||||
# query_text: Original search query
|
||||
@@ -489,7 +505,7 @@ def apply_reranker_placeholder(
|
||||
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
|
||||
# top_k: Maximum number of items to rerank per category
|
||||
# batch_size: Number of items to process concurrently
|
||||
|
||||
|
||||
# Returns:
|
||||
# Reranked results with final_score and reranker_model fields
|
||||
# """
|
||||
@@ -499,18 +515,18 @@ def apply_reranker_placeholder(
|
||||
# # except Exception as e:
|
||||
# # logger.debug(f"Failed to load reranker config: {e}")
|
||||
# # rc = {}
|
||||
|
||||
|
||||
# # Check if reranking is enabled
|
||||
# enabled = rc.get("enabled", False)
|
||||
# if not enabled:
|
||||
# logger.debug("LLM reranking is disabled in configuration")
|
||||
# return results
|
||||
|
||||
|
||||
# # Load configuration parameters with defaults
|
||||
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
|
||||
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
|
||||
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
|
||||
|
||||
|
||||
# # Initialize reranker client if not provided
|
||||
# if reranker_client is None:
|
||||
# try:
|
||||
@@ -518,10 +534,10 @@ def apply_reranker_placeholder(
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
|
||||
# return results
|
||||
|
||||
|
||||
# # Get model name for metadata
|
||||
# model_name = getattr(reranker_client, 'model_name', 'unknown')
|
||||
|
||||
|
||||
# # Process each category
|
||||
# reranked_results = {}
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
@@ -529,38 +545,38 @@ def apply_reranker_placeholder(
|
||||
# if not items:
|
||||
# reranked_results[category] = []
|
||||
# continue
|
||||
|
||||
|
||||
# # Select top K items by combined_score for reranking
|
||||
# sorted_items = sorted(
|
||||
# items,
|
||||
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
|
||||
# reverse=True
|
||||
# )
|
||||
|
||||
|
||||
# top_items = sorted_items[:top_k]
|
||||
# remaining_items = sorted_items[top_k:]
|
||||
|
||||
|
||||
# # Extract text content from each item
|
||||
# def extract_text(item: Dict[str, Any]) -> str:
|
||||
# """Extract text content from a result item."""
|
||||
# # Try different text fields based on category
|
||||
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
|
||||
# return str(text).strip()
|
||||
|
||||
|
||||
# # Batch items for concurrent processing
|
||||
# batches = []
|
||||
# for i in range(0, len(top_items), batch_size):
|
||||
# batch = top_items[i:i + batch_size]
|
||||
# batches.append(batch)
|
||||
|
||||
|
||||
# # Process batches concurrently
|
||||
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
# """Process a batch of items with LLM relevance scoring."""
|
||||
# scored_batch = []
|
||||
|
||||
|
||||
# for item in batch:
|
||||
# item_text = extract_text(item)
|
||||
|
||||
|
||||
# # Skip items with no text
|
||||
# if not item_text:
|
||||
# item_copy = item.copy()
|
||||
@@ -570,7 +586,7 @@ def apply_reranker_placeholder(
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
# continue
|
||||
|
||||
|
||||
# # Create relevance scoring prompt
|
||||
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
|
||||
|
||||
@@ -583,15 +599,15 @@ def apply_reranker_placeholder(
|
||||
# - 1.0 means perfectly relevant
|
||||
|
||||
# Relevance score:"""
|
||||
|
||||
|
||||
# # Send request to LLM
|
||||
# try:
|
||||
# messages = [{"role": "user", "content": prompt}]
|
||||
# response = await reranker_client.chat(messages)
|
||||
|
||||
|
||||
# # Parse LLM response to extract relevance score
|
||||
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
|
||||
|
||||
|
||||
# # Try to extract a float from the response
|
||||
# try:
|
||||
# # Remove any non-numeric characters except decimal point
|
||||
@@ -606,11 +622,11 @@ def apply_reranker_placeholder(
|
||||
# except (ValueError, AttributeError) as e:
|
||||
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
|
||||
# llm_score = None
|
||||
|
||||
|
||||
# # Calculate final score
|
||||
# item_copy = item.copy()
|
||||
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
|
||||
|
||||
|
||||
# if llm_score is not None:
|
||||
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
|
||||
# item_copy["llm_relevance_score"] = llm_score
|
||||
@@ -618,7 +634,7 @@ def apply_reranker_placeholder(
|
||||
# # Use combined_score as fallback
|
||||
# final_score = combined_score
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
|
||||
|
||||
# item_copy["final_score"] = final_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
@@ -630,14 +646,14 @@ def apply_reranker_placeholder(
|
||||
# item_copy["llm_relevance_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_batch.append(item_copy)
|
||||
|
||||
|
||||
# return scored_batch
|
||||
|
||||
|
||||
# # Process all batches concurrently
|
||||
# try:
|
||||
# batch_tasks = [process_batch(batch) for batch in batches]
|
||||
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# # Merge batch results
|
||||
# scored_items = []
|
||||
# for result in batch_results:
|
||||
@@ -645,7 +661,7 @@ def apply_reranker_placeholder(
|
||||
# logger.warning(f"Batch processing failed: {result}")
|
||||
# continue
|
||||
# scored_items.extend(result)
|
||||
|
||||
|
||||
# # Add remaining items (not in top K) with their combined_score as final_score
|
||||
# for item in remaining_items:
|
||||
# item_copy = item.copy()
|
||||
@@ -653,11 +669,11 @@ def apply_reranker_placeholder(
|
||||
# item_copy["final_score"] = combined_score
|
||||
# item_copy["reranker_model"] = model_name
|
||||
# scored_items.append(item_copy)
|
||||
|
||||
|
||||
# # Sort all items by final_score in descending order
|
||||
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
|
||||
# reranked_results[category] = scored_items
|
||||
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
|
||||
# # Return original items with combined_score as final_score
|
||||
@@ -666,22 +682,22 @@ def apply_reranker_placeholder(
|
||||
# item["final_score"] = combined_score
|
||||
# item["reranker_model"] = model_name
|
||||
# reranked_results[category] = items
|
||||
|
||||
|
||||
# return reranked_results
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -697,7 +713,7 @@ async def run_hybrid_search(
|
||||
|
||||
# Clean and normalize the incoming query before use/logging
|
||||
query_text = extract_plain_query(query_text)
|
||||
|
||||
|
||||
# Validate query is not empty after cleaning
|
||||
if not query_text or not query_text.strip():
|
||||
logger.warning("Empty query after cleaning, returning empty results")
|
||||
@@ -714,7 +730,7 @@ async def run_hybrid_search(
|
||||
"error": "Empty query"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
@@ -724,15 +740,16 @@ async def run_hybrid_search(
|
||||
try:
|
||||
keyword_task = None
|
||||
embedding_task = None
|
||||
keyword_results: Dict[str, List] = {}
|
||||
embedding_results: Dict[str, List] = {}
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
@@ -742,43 +759,48 @@ async def run_hybrid_search(
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"]
|
||||
)
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
except Exception as emb_init_err:
|
||||
logger.warning(
|
||||
f"[PERF] Embedding search skipped due to init error "
|
||||
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
|
||||
)
|
||||
embedding_task = None
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
keyword_latency = time.time() - search_start_time
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
@@ -788,7 +810,7 @@ async def run_hybrid_search(
|
||||
|
||||
if embedding_task:
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
embedding_latency = time.time() - search_start_time
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
@@ -800,7 +822,8 @@ async def run_hybrid_search(
|
||||
if search_type == "hybrid":
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
@@ -808,7 +831,7 @@ async def run_hybrid_search(
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
@@ -819,7 +842,7 @@ async def run_hybrid_search(
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
@@ -832,14 +855,14 @@ async def run_hybrid_search(
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
|
||||
|
||||
# Apply LLM reranking if enabled
|
||||
llm_rerank_applied = False
|
||||
# if use_llm_rerank:
|
||||
@@ -852,11 +875,12 @@ async def run_hybrid_search(
|
||||
# logger.info("LLM reranking applied successfully")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
|
||||
|
||||
|
||||
results["reranked_results"] = reranked_results
|
||||
results["combined_summary"] = {
|
||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_embedding_results": sum(
|
||||
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
@@ -869,17 +893,17 @@ async def run_hybrid_search(
|
||||
# Calculate total latency
|
||||
total_latency = time.time() - search_start_time
|
||||
latency_metrics["total_latency"] = round(total_latency, 4)
|
||||
|
||||
|
||||
# Add latency metrics to results
|
||||
if "combined_summary" in results:
|
||||
results["combined_summary"]["latency_metrics"] = latency_metrics
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
|
||||
logger.info("[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
logger.info("[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
@@ -898,8 +922,10 @@ async def run_hybrid_search(
|
||||
# Log search completion with result count
|
||||
if search_type == "hybrid":
|
||||
result_counts = {
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
keyword_results.items()},
|
||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||
embedding_results.items()}
|
||||
}
|
||||
else:
|
||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||
@@ -917,12 +943,12 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -958,13 +984,13 @@ async def search_by_temporal(
|
||||
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -1001,9 +1027,9 @@ async def search_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
Search for Chunks by chunk_id.
|
||||
@@ -1016,4 +1042,3 @@ async def search_chunk_by_chunk_id(
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||
|
||||
__all__ = ["LabelPropagationEngine"]
|
||||
@@ -0,0 +1,683 @@
|
||||
"""标签传播聚类引擎
|
||||
|
||||
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
|
||||
|
||||
支持两种模式:
|
||||
- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代
|
||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from math import sqrt
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全量迭代最大轮数,防止不收敛
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
# 社区核心实体取 top-N 数量
|
||||
CORE_ENTITY_LIMIT = 10
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||
if not v1 or not v2 or len(v1) != len(v2):
|
||||
return 0.0
|
||||
dot = sum(a * b for a, b in zip(v1, v2))
|
||||
norm1 = sqrt(sum(a * a for a in v1))
|
||||
norm2 = sqrt(sum(b * b for b in v2))
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
return dot / (norm1 * norm2)
|
||||
|
||||
|
||||
def _weighted_vote(
|
||||
neighbors: List[Dict],
|
||||
self_embedding: Optional[List[float]],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
加权多数投票,选出得票最高的社区。
|
||||
|
||||
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||
没有 community_id 的邻居不参与投票。
|
||||
"""
|
||||
votes: Dict[str, float] = {}
|
||||
for nb in neighbors:
|
||||
cid = nb.get("community_id")
|
||||
if not cid:
|
||||
continue
|
||||
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||
act = nb.get("activation_value") or 0.5
|
||||
# 语义相似度权重 0.6,激活值权重 0.4
|
||||
weight = 0.6 * sem + 0.4 * act
|
||||
votes[cid] = votes.get(cid, 0.0) + weight
|
||||
|
||||
if not votes:
|
||||
return None
|
||||
return max(votes, key=votes.__getitem__)
|
||||
|
||||
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
# 缓存客户端实例,避免重复初始化
|
||||
self._llm_client = None
|
||||
self._embedder_client = None
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_entity_ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
统一入口:自动判断全量还是增量。
|
||||
|
||||
- 若该用户尚无 Community 节点 → 全量初始化
|
||||
- 否则 → 增量更新(仅处理 new_entity_ids)
|
||||
"""
|
||||
has_communities = await self.repo.has_communities(end_user_id)
|
||||
if not has_communities:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
|
||||
await self.full_clustering(end_user_id)
|
||||
else:
|
||||
if new_entity_ids:
|
||||
logger.info(
|
||||
f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
|
||||
)
|
||||
await self.incremental_update(new_entity_ids, end_user_id)
|
||||
|
||||
async def full_clustering(self, end_user_id: str) -> None:
|
||||
"""
|
||||
全量标签传播初始化(分批处理,控制内存峰值)。
|
||||
|
||||
策略:
|
||||
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
|
||||
- labels 字典跨批次共享(只存 id→community_id,内存极小)
|
||||
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||
- 所有批次完成后统一 flush 和 merge
|
||||
"""
|
||||
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||
|
||||
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||
total_count = await self.repo.get_entity_count(end_user_id)
|
||||
if not total_count:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||
return
|
||||
|
||||
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
|
||||
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批")
|
||||
|
||||
# labels 跨批次共享:只存 id→community_id,内存极小
|
||||
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
|
||||
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
|
||||
|
||||
for batch_start in range(0, total_count, BATCH_SIZE):
|
||||
batch_entities = await self.repo.get_entities_page(
|
||||
end_user_id, skip=batch_start, limit=BATCH_SIZE
|
||||
)
|
||||
if not batch_entities:
|
||||
break
|
||||
|
||||
batch_ids = [e["id"] for e in batch_entities]
|
||||
batch_embeddings: Dict[str, Optional[List[float]]] = {
|
||||
e["id"]: e.get("name_embedding") for e in batch_entities
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:"
|
||||
f"加载 {len(batch_entities)} 个实体的邻居图..."
|
||||
)
|
||||
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
|
||||
batch_ids, end_user_id
|
||||
)
|
||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
changed = 0
|
||||
for entity in batch_entities:
|
||||
eid = entity["id"]
|
||||
neighbors = neighbors_cache.get(eid, [])
|
||||
|
||||
# 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值)
|
||||
enriched = []
|
||||
for nb in neighbors:
|
||||
nb_copy = dict(nb)
|
||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||
enriched.append(nb_copy)
|
||||
|
||||
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
|
||||
if new_label and new_label != labels[eid]:
|
||||
labels[eid] = new_label
|
||||
changed += 1
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
|
||||
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
|
||||
)
|
||||
if changed == 0:
|
||||
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
|
||||
break
|
||||
|
||||
# 释放本批次的大对象
|
||||
del neighbors_cache, batch_embeddings, batch_entities
|
||||
|
||||
# 所有批次完成,统一写入 Neo4j
|
||||
await self._flush_labels(labels, end_user_id)
|
||||
pre_merge_count = len(set(labels.values()))
|
||||
logger.info(
|
||||
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体,开始后处理合并"
|
||||
)
|
||||
|
||||
all_community_ids = list(set(labels.values()))
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
|
||||
# 查询存活社区并生成元数据
|
||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||
surviving_community_ids = list({
|
||||
e.get("community_id") for e in surviving_communities
|
||||
if e.get("community_id")
|
||||
})
|
||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
增量更新:只处理新实体及其邻居,不重跑全图。
|
||||
|
||||
1. 对每个新实体查询邻居
|
||||
2. 加权多数投票决定社区归属
|
||||
3. 若邻居无社区 → 创建新社区
|
||||
4. 若邻居分属多个社区 → 评估是否合并
|
||||
"""
|
||||
# 收集所有需要生成元数据的社区ID
|
||||
communities_to_update = set()
|
||||
|
||||
for entity_id in new_entity_ids:
|
||||
cid = await self._process_single_entity(entity_id, end_user_id)
|
||||
if cid:
|
||||
communities_to_update.add(cid)
|
||||
|
||||
# 批量生成所有社区的元数据
|
||||
if communities_to_update:
|
||||
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 内部方法
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _process_single_entity(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
处理单个新实体的社区分配。
|
||||
|
||||
该函数会为新实体分配社区,可能的情况包括:
|
||||
1. 孤立实体(无邻居):创建新的单成员社区
|
||||
2. 邻居都没有社区:创建新社区并将实体和邻居都加入
|
||||
3. 邻居有社区:通过加权投票选择最合适的社区加入
|
||||
|
||||
Returns:
|
||||
Optional[str]: 分配到的社区ID。当前实现总是返回一个有效的社区ID,
|
||||
但返回类型保留为Optional以支持未来可能的扩展场景
|
||||
(例如:实体无法分配到任何社区的情况)。
|
||||
调用方应检查返回值的真假性(truthiness)。
|
||||
"""
|
||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||
|
||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||
self_embedding = await self._get_entity_embedding(entity_id, end_user_id)
|
||||
|
||||
if not neighbors:
|
||||
# 孤立实体:创建单成员社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
return new_cid
|
||||
|
||||
# 统计邻居社区分布
|
||||
community_ids_in_neighbors = set(
|
||||
nb["community_id"] for nb in neighbors if nb.get("community_id")
|
||||
)
|
||||
|
||||
target_cid = _weighted_vote(neighbors, self_embedding)
|
||||
|
||||
if target_cid is None:
|
||||
# 邻居都没有社区,连同新实体一起创建新社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
for nb in neighbors:
|
||||
await self.repo.assign_entity_to_community(
|
||||
nb["id"], new_cid, end_user_id
|
||||
)
|
||||
await self.repo.refresh_member_count(new_cid, end_user_id)
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
return new_cid
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
await self.repo.refresh_member_count(target_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
|
||||
|
||||
# 若邻居分属多个社区,评估合并
|
||||
if len(community_ids_in_neighbors) > 1:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
# 返回目标社区ID,稍后批量生成元数据
|
||||
return target_cid
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
评估多个社区是否应合并。
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.85
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
cids = list(community_ids)
|
||||
for i in range(len(cids)):
|
||||
for j in range(i + 1, len(cids)):
|
||||
sim = _cosine_similarity(
|
||||
community_embeddings[cids[i]],
|
||||
community_embeddings[cids[j]],
|
||||
)
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
|
||||
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
|
||||
# (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并)
|
||||
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
|
||||
|
||||
def get_root(x: str) -> str:
|
||||
"""路径压缩,找到 x 当前所属的根社区。"""
|
||||
while x in merged_into:
|
||||
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
|
||||
x = merged_into[x]
|
||||
return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
root1, root2 = get_root(c1), get_root(c2)
|
||||
if root1 == root2:
|
||||
continue
|
||||
|
||||
# 用合并后的最新平均向量重新验证相似度
|
||||
# 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并
|
||||
current_sim = _cosine_similarity(
|
||||
community_embeddings.get(root1),
|
||||
community_embeddings.get(root2),
|
||||
)
|
||||
if current_sim <= MERGE_THRESHOLD:
|
||||
# 合并后向量已漂移,不再满足阈值,跳过
|
||||
logger.debug(
|
||||
f"[Clustering] 跳过合并 {root1} ↔ {root2},"
|
||||
f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
merged_into[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
|
||||
# 合并后重新计算 keep 的平均向量(加权平均)
|
||||
keep_emb = community_embeddings.get(keep)
|
||||
dissolve_emb = community_embeddings.get(dissolve)
|
||||
keep_size = community_sizes.get(keep, 0)
|
||||
dissolve_size = community_sizes.get(dissolve, 0)
|
||||
total_size = keep_size + dissolve_size
|
||||
if keep_emb and dissolve_emb and total_size > 0:
|
||||
dim = len(keep_emb)
|
||||
community_embeddings[keep] = [
|
||||
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
|
||||
for i in range(dim)
|
||||
]
|
||||
community_embeddings[dissolve] = None
|
||||
|
||||
community_sizes[keep] = total_size
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
|
||||
)
|
||||
|
||||
async def _flush_labels(
|
||||
self, labels: Dict[str, str], end_user_id: str
|
||||
) -> None:
|
||||
"""将内存中的标签批量写入 Neo4j。"""
|
||||
# 先创建所有唯一社区节点
|
||||
unique_communities = set(labels.values())
|
||||
for cid in unique_communities:
|
||||
await self.repo.upsert_community(cid, end_user_id)
|
||||
|
||||
# 再批量分配实体
|
||||
for entity_id, community_id in labels.items():
|
||||
await self.repo.assign_entity_to_community(
|
||||
entity_id, community_id, end_user_id
|
||||
)
|
||||
|
||||
# 刷新成员数
|
||||
for cid in unique_communities:
|
||||
await self.repo.refresh_member_count(cid, end_user_id)
|
||||
|
||||
async def _get_entity_embedding(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[List[float]]:
|
||||
"""查询单个实体的 name_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
|
||||
"RETURN e.name_embedding AS name_embedding",
|
||||
eid=entity_id,
|
||||
uid=end_user_id,
|
||||
)
|
||||
return result[0]["name_embedding"] if result else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
|
||||
lines = []
|
||||
for m in members:
|
||||
m_name = m.get("name", "")
|
||||
aliases = m.get("aliases") or []
|
||||
description = m.get("description") or ""
|
||||
example = m.get("example") or ""
|
||||
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||
desc_str = f":{description}" if description else ""
|
||||
example_str = f"(示例:{example})" if example else ""
|
||||
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。
|
||||
|
||||
流程:
|
||||
1. 批量准备所有社区的 prompt
|
||||
2. 并发调用 LLM 生成所有社区的 name / summary
|
||||
3. 批量 embed 所有 summary
|
||||
4. 批量写入数据库
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
async def _prepare_one(cid: str) -> Optional[Dict]:
|
||||
"""准备单个社区的数据和 prompt"""
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||
return None
|
||||
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
# 默认值
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
# 准备 LLM prompt(如果配置了 LLM)
|
||||
prompt = None
|
||||
if self.llm_model_id:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"prompt": prompt,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# --- 阶段1:并发准备所有社区数据 ---
|
||||
results = await asyncio.gather(
|
||||
*[_prepare_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
metadata_list = []
|
||||
for cid, res in zip(community_ids, results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
|
||||
elif res is not None:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量调用 LLM 生成 name 和 summary ---
|
||||
if self.llm_model_id:
|
||||
llm_client = self._get_llm_client()
|
||||
if not llm_client:
|
||||
logger.warning(
|
||||
f"[Clustering] LLM 已配置(model_id={self.llm_model_id})但客户端初始化失败,"
|
||||
f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||
)
|
||||
if llm_client:
|
||||
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
|
||||
|
||||
if prompts_to_process:
|
||||
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
|
||||
|
||||
async def _call_llm(idx: int, meta: Dict) -> tuple:
|
||||
"""单个 LLM 调用"""
|
||||
try:
|
||||
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
return (idx, text, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
|
||||
return (idx, None, e)
|
||||
|
||||
# 并发调用所有 LLM 请求
|
||||
llm_results = await asyncio.gather(
|
||||
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# 解析 LLM 响应
|
||||
for result in llm_results:
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
idx, text, error = result
|
||||
if error or not text:
|
||||
continue
|
||||
|
||||
meta = metadata_list[idx]
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
meta["name"] = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
meta["summary"] = line[3:].strip()
|
||||
|
||||
logger.info(f"[Clustering] LLM 批量生成完成")
|
||||
|
||||
# --- 阶段3:批量生成 summary_embedding ---
|
||||
if self.embedding_model_id:
|
||||
embedder = self._get_embedder_client()
|
||||
if not embedder:
|
||||
logger.warning(
|
||||
f"[Clustering] Embedding 已配置(model_id={self.embedding_model_id})但客户端初始化失败,"
|
||||
f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||
)
|
||||
if embedder:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
logger.info(f"[Clustering] Embedding 批量生成完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段4:批量写入数据库 ---
|
||||
# 移除 prompt 字段(不需要存储)
|
||||
for m in metadata_list:
|
||||
m.pop("prompt", None)
|
||||
|
||||
if len(metadata_list) == 1:
|
||||
m = metadata_list[0]
|
||||
result = await self.repo.update_community_metadata(
|
||||
community_id=m["community_id"],
|
||||
end_user_id=m["end_user_id"],
|
||||
name=m["name"],
|
||||
summary=m["summary"],
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if not result:
|
||||
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
else:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
|
||||
def _get_llm_client(self):
|
||||
"""获取或创建 LLM 客户端(单例模式)"""
|
||||
if self._llm_client is None and self.llm_model_id:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
with get_db_context() as db:
|
||||
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
|
||||
return self._llm_client
|
||||
|
||||
def _get_embedder_client(self):
|
||||
"""获取或创建 Embedder 客户端(单例模式)"""
|
||||
if self._embedder_client is None and self.embedding_model_id:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
with get_db_context() as db:
|
||||
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
|
||||
return self._embedder_client
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
@@ -9,6 +9,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
@@ -20,13 +21,26 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||
SceneConfigRegistry,
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def message_has_files(message: "ConversationMessage") -> bool:
|
||||
"""检查消息是否包含文件。
|
||||
|
||||
Args:
|
||||
message: 待检查的消息对象
|
||||
|
||||
Returns:
|
||||
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||
"""
|
||||
return message.files and len(message.files) > 0
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
@@ -34,6 +48,8 @@ class DialogExtractionResponse(BaseModel):
|
||||
- is_related:对话与场景的相关性判定。
|
||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||
- scene_unrelated_snippets:与当前场景无关且无语义关联的消息片段(原文截取),
|
||||
用于高阈值阶段精准删除跨场景内容。
|
||||
"""
|
||||
is_related: bool = Field(...)
|
||||
times: List[str] = Field(default_factory=list)
|
||||
@@ -43,6 +59,7 @@ class DialogExtractionResponse(BaseModel):
|
||||
addresses: List[str] = Field(default_factory=list)
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||
scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容")
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
@@ -91,12 +108,14 @@ class SemanticPruner:
|
||||
# 加载统一填充词库
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||
|
||||
# 本体类型列表(用于注入提示词,所有场景均支持)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
# 本体类型列表:直接使用 ontology_class_infos(name + description)
|
||||
self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or []
|
||||
# _ontology_classes 仅用于日志统计
|
||||
self._ontology_classes = [info.class_name for info in self._ontology_class_infos]
|
||||
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
if self._ontology_class_infos:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
@@ -121,7 +140,8 @@ class SemanticPruner:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 纯表情/标点
|
||||
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
5. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
@@ -143,6 +163,55 @@ class SemanticPruner:
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
# 组合寒暄模式:短消息(≤15字)且完全由寒暄成分构成
|
||||
# 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充
|
||||
if len(t) <= 15:
|
||||
# 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢"
|
||||
_confirm_prefixes = {"好的", "好", "嗯", "嗯嗯", "哦", "明白", "明白了", "知道了", "了解", "收到", "没问题"}
|
||||
_thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"}
|
||||
_greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"}
|
||||
_greeting_prefixes = {"同学", "老师", "您好", "你好"}
|
||||
_close_patterns = {
|
||||
"没有了", "没事了", "没问题了", "好了", "行了", "可以了",
|
||||
"不用了", "不需要了", "就这样", "就这样吧", "那就这样",
|
||||
}
|
||||
_polite_responses = {
|
||||
"不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的",
|
||||
}
|
||||
|
||||
# 规则1:确认词 + 感谢词(如"好的谢谢"、"嗯谢谢")
|
||||
for cp in _confirm_prefixes:
|
||||
for ts in _thanks_suffixes:
|
||||
if t == cp + ts or t == cp + "," + ts or t == cp + "," + ts:
|
||||
return True
|
||||
|
||||
# 规则2:称呼前缀 + 问候(如"同学你好"、"老师好")
|
||||
for gp in _greeting_prefixes:
|
||||
for gs in _greeting_suffixes:
|
||||
if t == gp + gs or t.startswith(gp) and t.endswith("好"):
|
||||
return True
|
||||
|
||||
# 规则3:结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢")
|
||||
for cp in _close_patterns:
|
||||
if t.startswith(cp):
|
||||
remainder = t[len(cp):].lstrip(",,、 ")
|
||||
if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes):
|
||||
return True
|
||||
|
||||
# 规则4:礼貌回应(如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话)
|
||||
for pr in _polite_responses:
|
||||
if t.startswith(pr):
|
||||
remainder = t[len(pr):].lstrip(",,、 ")
|
||||
# 后半是祝福/套话(不含实质信息)
|
||||
if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder):
|
||||
return True
|
||||
|
||||
# 规则5:纯确认词加"了"后缀(如"明白了"、"知道了"、"好了")
|
||||
_confirm_base = {"明白", "知道", "了解", "收到", "好", "行", "可以", "没问题"}
|
||||
for cb in _confirm_base:
|
||||
if t == cb + "了" or t == cb + "了。" or t == cb + "了!":
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
@@ -331,13 +400,13 @@ class SemanticPruner:
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
ontology_class_infos=self._ontology_class_infos,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"ontology_class_infos_count": len(self._ontology_class_infos),
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
@@ -377,6 +446,193 @@ class SemanticPruner:
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _get_pruning_mode(self) -> str:
|
||||
"""根据 pruning_threshold 返回当前剪枝阶段。
|
||||
|
||||
- 低阈值 [0.0, 0.3):conservative 只删填充,保留所有实质内容
|
||||
- 中阈值 [0.3, 0.6):semantic 保留场景相关 + 有语义关联的内容,删除无关联内容
|
||||
- 高阈值 [0.6, 0.9]:strict 只保留场景相关内容,跨场景内容可被删除
|
||||
"""
|
||||
t = float(self.config.pruning_threshold)
|
||||
if t < 0.3:
|
||||
return "conservative"
|
||||
elif t < 0.6:
|
||||
return "semantic"
|
||||
else:
|
||||
return "strict"
|
||||
|
||||
def _apply_related_dialog_pruning(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
extraction: "DialogExtractionResponse",
|
||||
dialog_label: str,
|
||||
pruning_mode: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""相关对话统一剪枝入口,消除 prune_dialog / prune_dataset 中的重复逻辑。
|
||||
|
||||
- conservative:只删填充
|
||||
- semantic / strict:场景感知剪枝
|
||||
"""
|
||||
if pruning_mode == "conservative":
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
return self._prune_fillers_only(msgs, preserve_tokens, dialog_label)
|
||||
else:
|
||||
return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode)
|
||||
|
||||
def _prune_fillers_only(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
preserve_tokens: List[str],
|
||||
dialog_label: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""相关对话专用:只删填充消息,LLM 保护消息和实质内容一律保留。
|
||||
|
||||
不受 pruning_threshold 约束,删多少算多少(填充有多少删多少)。
|
||||
至少保留 1 条消息。
|
||||
注意:填充检测优先于 preserve_tokens 保护——填充消息本身无信息价值,
|
||||
即使 LLM 误将其关键词放入 preserve_tokens 也应删除。
|
||||
"""
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [填充] '{m.msg[:40]}' → 删除")
|
||||
continue
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
self._log(f" [保护] '{m.msg[:40]}' → LLM保护,跳过")
|
||||
|
||||
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
|
||||
deleted = len(msgs) - len(kept)
|
||||
self._log(
|
||||
f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} "
|
||||
f"填充删除={deleted} 保留={len(kept)}"
|
||||
)
|
||||
return kept
|
||||
|
||||
def _prune_with_scene_filter(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
extraction: "DialogExtractionResponse",
|
||||
dialog_label: str,
|
||||
mode: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""场景感知剪枝,供 semantic / strict 两个阈值档位调用。
|
||||
|
||||
本函数体现剪枝系统的三层递进逻辑:
|
||||
|
||||
第一层(conservative,阈值 < 0.3):
|
||||
不进入本函数,由 _prune_fillers_only 处理。
|
||||
保留标准:只问"有没有信息量",填充消息(嗯/好的/哈哈等)删除,其余一律保留。
|
||||
|
||||
第二层(semantic,阈值 [0.3, 0.6)):
|
||||
保留标准:内容价值优先,场景相关性是参考而非唯一标准。
|
||||
- 填充消息 → 删除(最高优先级)
|
||||
- 场景相关消息 → 保留
|
||||
- 场景无关消息 → 有两次豁免机会:
|
||||
1. 命中 scene_preserve_tokens(LLM 标记的关键词/时间/金额等)→ 保留
|
||||
2. 含情感词(感觉/压力/开心等)→ 保留(情感内容有记忆价值)
|
||||
3. 两次豁免均未命中 → 删除
|
||||
|
||||
第三层(strict,阈值 [0.6, 0.9]):
|
||||
保留标准:场景相关性优先,无任何豁免。
|
||||
- 填充消息 → 删除(最高优先级)
|
||||
- 场景相关消息 → 保留
|
||||
- 场景无关消息 → 直接删除,preserve_keywords 和情感词在此模式下均不生效
|
||||
|
||||
至少保留 1 条消息(兜底取第一条)。
|
||||
"""
|
||||
# strict 模式收窄保护范围:只保护结构化关键信息(时间/编号/金额/联系方式/地址),
|
||||
# 不保护 keywords / preserve_keywords,让场景过滤能删掉更多内容。
|
||||
# semantic 模式完整保护:包含 LLM 抽取的所有重要片段(含 keywords 和 preserve_keywords)。
|
||||
if mode == "strict":
|
||||
scene_preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses
|
||||
)
|
||||
else:
|
||||
scene_preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
|
||||
unrelated_snippets = extraction.scene_unrelated_snippets or []
|
||||
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [填充] '{msg_text[:40]}' → 删除")
|
||||
continue
|
||||
|
||||
# 双向包含匹配:处理 LLM 返回片段与原始消息文本长度不完全一致的情况
|
||||
is_scene_unrelated = any(
|
||||
snip and (snip in msg_text or msg_text in snip)
|
||||
for snip in unrelated_snippets
|
||||
)
|
||||
|
||||
if is_scene_unrelated:
|
||||
if mode == "strict":
|
||||
# strict:场景无关直接删除,不做任何豁免
|
||||
# 场景相关性是唯一裁决标准,preserve_keywords 在此模式下不生效
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除")
|
||||
elif mode == "semantic":
|
||||
# semantic:场景无关但有内容价值 → 保留
|
||||
# 豁免第一层:命中 scene_preserve_tokens(关键词/结构化信息保护)
|
||||
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||
self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留")
|
||||
else:
|
||||
# 豁免第二层:含情感词,认为有情境记忆价值,即使场景无关也保留
|
||||
has_contextual_emotion = any(
|
||||
word in msg_text
|
||||
for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧",
|
||||
"喜欢", "讨厌", "爱", "恨", "担心", "害怕", "兴奋",
|
||||
"压力", "累", "疲惫", "烦", "焦虑", "委屈", "感动"]
|
||||
)
|
||||
if not has_contextual_emotion:
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)")
|
||||
else:
|
||||
self._log(f" [场景关联-保留] '{msg_text[:40]}' → 有情感关联,保留")
|
||||
else:
|
||||
# 不在 scene_unrelated_snippets 中 → 场景相关,直接保留
|
||||
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||
self._log(f" [保护] '{msg_text[:40]}' → LLM保护,跳过")
|
||||
# else: 普通场景相关消息,保留,不输出日志
|
||||
|
||||
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
|
||||
deleted = len(msgs) - len(kept)
|
||||
self._log(
|
||||
f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} "
|
||||
f"删除={deleted} 保留={len(kept)}"
|
||||
)
|
||||
return kept
|
||||
|
||||
def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]:
|
||||
"""统一构建 preserve_tokens,合并 LLM 抽取的所有重要片段。"""
|
||||
return (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords
|
||||
)
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
if not tokens:
|
||||
@@ -397,16 +653,18 @@ class SemanticPruner:
|
||||
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
extraction = await self._extract_dialog_important(dialog.content)
|
||||
pruning_mode = self._get_pruning_mode()
|
||||
self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}")
|
||||
|
||||
if extraction.is_related:
|
||||
# 相关对话不剪枝
|
||||
kept = self._apply_related_dialog_pruning(
|
||||
dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode
|
||||
)
|
||||
dialog.context = ConversationContext(msgs=kept)
|
||||
return dialog
|
||||
|
||||
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords
|
||||
)
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
msgs = dialog.context.msgs
|
||||
|
||||
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||
@@ -473,7 +731,7 @@ class SemanticPruner:
|
||||
# 阈值保护:最高0.9
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
if proportion > 0.9:
|
||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
@@ -481,11 +739,30 @@ class SemanticPruner:
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||
)
|
||||
|
||||
|
||||
pruning_mode = self._get_pruning_mode()
|
||||
self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}")
|
||||
|
||||
result: List[DialogData] = []
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
# 统计对象:直接收集结构化数据,无需事后正则解析
|
||||
stats = {
|
||||
"scene": self.config.pruning_scene,
|
||||
"dialog_total": len(dialogs),
|
||||
"deletion_ratio": proportion,
|
||||
"enabled": self.config.pruning_switch,
|
||||
"pruning_mode": pruning_mode,
|
||||
"related_count": 0,
|
||||
"unrelated_count": 0,
|
||||
"related_indices": [],
|
||||
"unrelated_indices": [],
|
||||
"total_deleted_messages": 0,
|
||||
"remaining_dialogs": 0,
|
||||
"dialogs": [],
|
||||
}
|
||||
|
||||
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
@@ -505,12 +782,31 @@ class SemanticPruner:
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# 相关对话:根据阶段决定处理力度
|
||||
if extraction.is_related:
|
||||
stats["related_count"] += 1
|
||||
stats["related_indices"].append(d_idx + 1)
|
||||
kept = self._apply_related_dialog_pruning(
|
||||
msgs, extraction, f"对话 {d_idx+1}", pruning_mode
|
||||
)
|
||||
deleted_count = original_count - len(kept)
|
||||
total_deleted_msgs += deleted_count
|
||||
dd.context.msgs = kept
|
||||
result.append(dd)
|
||||
stats["dialogs"].append({
|
||||
"index": d_idx + 1,
|
||||
"is_related": True,
|
||||
"total_messages": original_count,
|
||||
"deleted": deleted_count,
|
||||
"kept": len(kept),
|
||||
})
|
||||
continue
|
||||
|
||||
stats["unrelated_count"] += 1
|
||||
stats["unrelated_indices"].append(d_idx + 1)
|
||||
|
||||
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||
preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
|
||||
)
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
|
||||
# 判断是否需要详细日志
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
@@ -527,6 +823,12 @@ class SemanticPruner:
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||
continue
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
@@ -543,16 +845,16 @@ class SemanticPruner:
|
||||
|
||||
# important_msgs 仅用于日志统计
|
||||
important_msgs = llm_protected_msgs
|
||||
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
if proportion > 0 and original_count > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
|
||||
|
||||
# 确保至少保留1条消息
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
|
||||
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = []
|
||||
@@ -570,58 +872,73 @@ class SemanticPruner:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
|
||||
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
for idx, m in enumerate(msgs):
|
||||
if idx not in to_delete_indices:
|
||||
kept_msgs.append(m)
|
||||
|
||||
|
||||
# 确保至少保留1条
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
|
||||
dd.context.msgs = kept_msgs
|
||||
deleted_count = original_count - len(kept_msgs)
|
||||
total_deleted_msgs += deleted_count
|
||||
|
||||
|
||||
# 输出删除详情
|
||||
if deleted_details:
|
||||
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
|
||||
for detail in deleted_details:
|
||||
self._log(f" {detail}")
|
||||
|
||||
|
||||
# ========== 问答对统计(已注释) ==========
|
||||
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
|
||||
# ========================================
|
||||
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
result.append(dd)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 保存日志
|
||||
stats["dialogs"].append({
|
||||
"index": d_idx + 1,
|
||||
"is_related": False,
|
||||
"total_messages": original_count,
|
||||
"protected": len(important_msgs),
|
||||
"fillers": len(filler_msgs),
|
||||
"deletable": len(deletable_msgs),
|
||||
"deleted": deleted_count,
|
||||
"kept": len(kept_msgs),
|
||||
})
|
||||
|
||||
result.append(dd)
|
||||
|
||||
# 补全统计对象
|
||||
stats["total_deleted_messages"] = total_deleted_msgs
|
||||
stats["remaining_dialogs"] = len(result)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}")
|
||||
self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs} 条")
|
||||
|
||||
# 直接序列化统计对象,无需正则解析
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
|
||||
|
||||
# Safety: avoid empty dataset
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
@@ -629,118 +946,7 @@ class SemanticPruner:
|
||||
try:
|
||||
self.run_logs.append(msg)
|
||||
except Exception:
|
||||
# 任何异常都不影响打印
|
||||
pass
|
||||
print(msg)
|
||||
logger.debug(msg)
|
||||
|
||||
def _sanitize_log_line(self, line: str) -> str:
|
||||
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
|
||||
try:
|
||||
return re.sub(r"^\[[^\]]+\]\s*", "", line)
|
||||
except Exception:
|
||||
return line
|
||||
|
||||
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
|
||||
"""将已去前缀的日志列表解析为结构化 JSON,便于数据对接。"""
|
||||
summary = {
|
||||
"scene": self.config.pruning_scene,
|
||||
"dialog_total": None,
|
||||
"deletion_ratio": None,
|
||||
"enabled": None,
|
||||
"related_count": None,
|
||||
"unrelated_count": None,
|
||||
"related_indices": [],
|
||||
"unrelated_indices": [],
|
||||
"total_deleted_messages": None,
|
||||
"remaining_dialogs": None,
|
||||
}
|
||||
dialogs = []
|
||||
|
||||
# 解析函数
|
||||
def parse_int(value: str) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_float(value: str) -> Optional[float]:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_indices(s: str) -> List[int]:
|
||||
s = s.strip()
|
||||
if not s:
|
||||
return []
|
||||
parts = [p.strip() for p in s.split(",") if p.strip()]
|
||||
out: List[int] = []
|
||||
for p in parts:
|
||||
try:
|
||||
out.append(int(p))
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
# 正则
|
||||
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
|
||||
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
|
||||
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
|
||||
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
|
||||
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
|
||||
re_remaining = re.compile(r"剩余对话数=(\d+)")
|
||||
|
||||
for line in logs:
|
||||
# 第一行:总览
|
||||
m = re_header.search(line)
|
||||
if m:
|
||||
summary["dialog_total"] = parse_int(m.group(1))
|
||||
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
|
||||
summary["deletion_ratio"] = parse_float(m.group(3))
|
||||
summary["enabled"] = True if m.group(4) == "True" else False
|
||||
continue
|
||||
|
||||
# 第二行:相关/不相关数量
|
||||
m = re_counts.search(line)
|
||||
if m:
|
||||
summary["related_count"] = parse_int(m.group(1))
|
||||
summary["unrelated_count"] = parse_int(m.group(2))
|
||||
continue
|
||||
|
||||
# 第三行:相关/不相关索引
|
||||
m = re_indices.search(line)
|
||||
if m:
|
||||
summary["related_indices"] = parse_indices(m.group(1))
|
||||
summary["unrelated_indices"] = parse_indices(m.group(2))
|
||||
continue
|
||||
|
||||
# 对话级统计
|
||||
m = re_dialog.search(line)
|
||||
if m:
|
||||
dialogs.append({
|
||||
"index": parse_int(m.group(1)),
|
||||
"total_messages": parse_int(m.group(2)),
|
||||
"quota_delete": parse_int(m.group(3)),
|
||||
"actual_deleted": parse_int(m.group(4)),
|
||||
"kept": parse_int(m.group(5)),
|
||||
})
|
||||
continue
|
||||
|
||||
# 全局删除总数
|
||||
m = re_total_del.search(line)
|
||||
if m:
|
||||
summary["total_deleted_messages"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
# 剩余对话数
|
||||
m = re_remaining.search(line)
|
||||
if m:
|
||||
summary["remaining_dialogs"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
return {
|
||||
"scene": summary["scene"],
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"summary": {k: v for k, v in summary.items() if k != "scene"},
|
||||
"dialogs": dialogs,
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import asyncio
|
||||
import difflib # 提供字符串相似度计算工具
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
@@ -16,6 +17,8 @@ from app.core.memory.models.graph_models import (
|
||||
)
|
||||
from app.core.memory.models.variate_config import DedupConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 模块级类型统一工具函数
|
||||
def _unify_entity_type(canonical: ExtractedEntityNode, losing: ExtractedEntityNode, suggested_type: str = None) -> None:
|
||||
@@ -79,51 +82,38 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
canonical.connect_strength = next(iter(pair))
|
||||
|
||||
# 别名合并(去重保序,使用标准化工具)
|
||||
# 用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,去重合并时不修改
|
||||
try:
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
existing = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(existing)
|
||||
|
||||
# 2. 添加incoming实体的名称(如果不同于canonical的名称)
|
||||
if incoming_name and incoming_name != canonical_name:
|
||||
all_aliases.append(incoming_name)
|
||||
|
||||
# 3. 添加incoming实体的所有别名
|
||||
incoming = getattr(ent, "aliases", []) or []
|
||||
all_aliases.extend(incoming)
|
||||
|
||||
# 4. 标准化并去重(优先使用alias_utils工具函数)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
if canonical_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
incoming_name = (getattr(ent, "name", "") or "").strip()
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
# 收集所有需要合并的别名,过滤掉用户占位名避免污染非用户实体
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if incoming_name and incoming_name != canonical_name and incoming_name.lower() not in _USER_PLACEHOLDER_NAMES:
|
||||
all_aliases.append(incoming_name)
|
||||
all_aliases.extend(
|
||||
a for a in (getattr(ent, "aliases", []) or [])
|
||||
if a and a.strip().lower() not in _USER_PLACEHOLDER_NAMES
|
||||
)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
alias_normalized = alias_stripped.lower()
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -198,11 +188,167 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 用户和AI助手的占位名称集合(用于名称标准化)
|
||||
_USER_PLACEHOLDER_NAMES = {"用户", "我", "user", "i"}
|
||||
_ASSISTANT_PLACEHOLDER_NAMES = {"ai助手", "助手", "人工智能助手", "智能助手", "智能体", "ai assistant", "assistant"}
|
||||
|
||||
# 标准化后的规范名称和类型
|
||||
_CANONICAL_USER_NAME = "用户"
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
_CANONICAL_ASSISTANT_NAME = "AI助手"
|
||||
_CANONICAL_ASSISTANT_TYPE = "Agent"
|
||||
|
||||
# 用户和AI助手的所有可能名称(用于判断实体是否为特殊角色实体)
|
||||
_ALL_USER_NAMES = _USER_PLACEHOLDER_NAMES
|
||||
_ALL_ASSISTANT_NAMES = _ASSISTANT_PLACEHOLDER_NAMES
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
def _is_assistant_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为AI助手实体(name 或 entity_type 匹配)"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _ALL_ASSISTANT_NAMES or etype == _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
|
||||
def _would_merge_cross_role(a: ExtractedEntityNode, b: ExtractedEntityNode) -> bool:
|
||||
"""判断两个实体的合并是否会跨越用户/AI助手角色边界。
|
||||
|
||||
用户实体和AI助手实体永远不应该被合并在一起。
|
||||
如果一方是用户实体、另一方是AI助手实体,返回 True(阻止合并)。
|
||||
"""
|
||||
return (
|
||||
(_is_user_entity(a) and _is_assistant_entity(b))
|
||||
or (_is_assistant_entity(a) and _is_user_entity(b))
|
||||
)
|
||||
|
||||
|
||||
def _normalize_special_entity_names(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
) -> None:
|
||||
"""标准化用户和AI助手实体的名称和类型。
|
||||
|
||||
多轮对话中,LLM 对同一角色可能使用不同的名称变体(如"用户"/"我"/"User",
|
||||
"AI助手"/"助手"/"Assistant"),导致精确匹配无法合并。
|
||||
此函数在去重前将这些变体统一为规范名称,并强制绑定 entity_type,确保:
|
||||
- name="用户" 的实体 entity_type 一定为 "用户"
|
||||
- name="AI助手" 的实体 entity_type 一定为 "Agent"
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
"""
|
||||
for ent in entity_nodes:
|
||||
name = (getattr(ent, "name", "") or "").strip()
|
||||
name_lower = name.lower()
|
||||
|
||||
if name_lower in _USER_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_USER_NAME
|
||||
ent.entity_type = _CANONICAL_USER_TYPE
|
||||
elif name_lower in _ASSISTANT_PLACEHOLDER_NAMES:
|
||||
ent.name = _CANONICAL_ASSISTANT_NAME
|
||||
ent.entity_type = _CANONICAL_ASSISTANT_TYPE
|
||||
|
||||
# 第二步:清洗用户/AI助手之间的别名交叉污染(复用 clean_cross_role_aliases)
|
||||
clean_cross_role_aliases(entity_nodes)
|
||||
|
||||
|
||||
async def fetch_neo4j_assistant_aliases(neo4j_connector, end_user_id: str) -> set:
|
||||
"""从 Neo4j 查询 AI 助手实体的所有别名(小写归一化)。
|
||||
|
||||
这是助手别名查询的唯一入口,供 write_tools 和 extraction_orchestrator 共用,
|
||||
避免多处维护相同的 Cypher 和名称列表。
|
||||
|
||||
Args:
|
||||
neo4j_connector: Neo4j 连接器实例(需提供 execute_query 方法)
|
||||
end_user_id: 终端用户 ID
|
||||
|
||||
Returns:
|
||||
小写归一化后的助手别名集合
|
||||
"""
|
||||
# 查询名称列表:规范名称 + 常见变体(与 _normalize_special_entity_names 标准化后一致)
|
||||
query_names = [_CANONICAL_ASSISTANT_NAME, *_ASSISTANT_PLACEHOLDER_NAMES]
|
||||
# 去重保序
|
||||
query_names = list(dict.fromkeys(query_names))
|
||||
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN $names
|
||||
RETURN e.aliases AS aliases
|
||||
"""
|
||||
try:
|
||||
result = await neo4j_connector.execute_query(
|
||||
cypher, end_user_id=end_user_id, names=query_names
|
||||
)
|
||||
assistant_aliases: set = set()
|
||||
for record in (result or []):
|
||||
for alias in (record.get("aliases") or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
if assistant_aliases:
|
||||
logger.debug(f"Neo4j 中 AI 助手别名: {assistant_aliases}")
|
||||
return assistant_aliases
|
||||
except Exception as e:
|
||||
logger.warning(f"查询 Neo4j AI 助手别名失败: {e}")
|
||||
return set()
|
||||
|
||||
|
||||
def clean_cross_role_aliases(
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
external_assistant_aliases: set = None,
|
||||
) -> None:
|
||||
"""清洗用户实体和AI助手实体之间的别名交叉污染。
|
||||
|
||||
在 Neo4j 写入前调用,确保:
|
||||
- 用户实体的 aliases 不包含 AI 助手的别名
|
||||
- AI 助手实体的 aliases 不包含用户的别名
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表(原地修改)
|
||||
external_assistant_aliases: 外部传入的 AI 助手别名集合(如从 Neo4j 查询),
|
||||
与本轮实体中的 AI 助手别名合并使用
|
||||
"""
|
||||
# 收集本轮 AI 助手实体的所有别名
|
||||
assistant_aliases = set(external_assistant_aliases or set())
|
||||
user_aliases = set()
|
||||
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
assistant_aliases.add(alias.strip().lower())
|
||||
elif _is_user_entity(ent):
|
||||
for alias in (getattr(ent, "aliases", []) or []):
|
||||
user_aliases.add(alias.strip().lower())
|
||||
|
||||
# 从用户实体的 aliases 中移除 AI 助手别名
|
||||
if assistant_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in assistant_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
# 从 AI 助手实体的 aliases 中移除用户别名
|
||||
if user_aliases:
|
||||
for ent in entity_nodes:
|
||||
if _is_assistant_entity(ent):
|
||||
original = getattr(ent, "aliases", []) or []
|
||||
cleaned = [a for a in original if a.strip().lower() not in user_aliases]
|
||||
if len(cleaned) < len(original):
|
||||
ent.aliases = cleaned
|
||||
|
||||
|
||||
def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
同时检测某实体的 name 是否命中另一实体的 aliases,若命中则直接合并。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
@@ -240,6 +386,52 @@ def accurate_match(
|
||||
pass
|
||||
|
||||
deduped_entities = list(canonical_map.values())
|
||||
|
||||
# 2) 第二轮:检测某实体的 name 是否命中另一实体的 aliases(alias-to-name 精确合并)
|
||||
# 场景:LLM 把 aliases 中的词(如"齐齐")又单独抽取为独立实体,需在此阶段合并掉
|
||||
# 优化:先构建 (end_user_id, alias_lower) -> canonical 的反向索引,查找 O(1)
|
||||
alias_index: Dict[tuple, ExtractedEntityNode] = {}
|
||||
for canonical in deduped_entities:
|
||||
uid = getattr(canonical, "end_user_id", None)
|
||||
for alias in (getattr(canonical, "aliases", []) or []):
|
||||
alias_lower = alias.strip().lower()
|
||||
if alias_lower:
|
||||
alias_index[(uid, alias_lower)] = canonical
|
||||
|
||||
i = 0
|
||||
while i < len(deduped_entities):
|
||||
ent = deduped_entities[i]
|
||||
ent_name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
ent_uid = getattr(ent, "end_user_id", None)
|
||||
canonical = alias_index.get((ent_uid, ent_name))
|
||||
# 确保不是自身
|
||||
if canonical is not None and canonical.id != ent.id:
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(canonical, ent):
|
||||
i += 1
|
||||
continue
|
||||
_merge_attribute(canonical, ent)
|
||||
id_redirect[ent.id] = canonical.id
|
||||
for k, v in list(id_redirect.items()):
|
||||
if v == ent.id:
|
||||
id_redirect[k] = canonical.id
|
||||
try:
|
||||
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": canonical.id,
|
||||
"end_user_id": canonical.end_user_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
}
|
||||
exact_merge_map[k]["merged_ids"].add(ent.id)
|
||||
except Exception:
|
||||
pass
|
||||
deduped_entities.pop(i)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return deduped_entities, id_redirect, exact_merge_map
|
||||
|
||||
def fuzzy_match(
|
||||
@@ -528,66 +720,37 @@ def fuzzy_match(
|
||||
|
||||
|
||||
def _merge_entities_with_aliases(canonical: ExtractedEntityNode, losing: ExtractedEntityNode):
|
||||
""" 模糊匹配中的实体合并。
|
||||
"""模糊匹配中的实体合并(别名部分)。
|
||||
|
||||
合并策略:
|
||||
1. 保留canonical的主名称不变
|
||||
2. 将losing的主名称添加为alias(如果不同)
|
||||
3. 合并两个实体的所有aliases
|
||||
4. 自动去重(case-insensitive)并排序
|
||||
|
||||
Args:
|
||||
canonical: 规范实体(保留)
|
||||
losing: 被合并实体(删除)
|
||||
|
||||
Note:
|
||||
使用alias_utils.normalize_aliases进行标准化去重
|
||||
用户实体的 aliases 由 PgSQL end_user_info 作为唯一权威源,跳过合并。
|
||||
"""
|
||||
# 获取规范实体的名称
|
||||
canonical_name = (getattr(canonical, "name", "") or "").strip()
|
||||
if canonical_name.lower() in _USER_PLACEHOLDER_NAMES:
|
||||
return
|
||||
|
||||
losing_name = (getattr(losing, "name", "") or "").strip()
|
||||
|
||||
# 收集所有需要合并的别名
|
||||
all_aliases = []
|
||||
|
||||
# 1. 添加canonical现有的别名
|
||||
current_aliases = getattr(canonical, "aliases", []) or []
|
||||
all_aliases.extend(current_aliases)
|
||||
|
||||
# 2. 添加losing实体的名称(如果不同于canonical的名称)
|
||||
all_aliases = list(getattr(canonical, "aliases", []) or [])
|
||||
if losing_name and losing_name != canonical_name:
|
||||
all_aliases.append(losing_name)
|
||||
all_aliases.extend(getattr(losing, "aliases", []) or [])
|
||||
|
||||
# 3. 添加losing实体的所有别名
|
||||
losing_aliases = getattr(losing, "aliases", []) or []
|
||||
all_aliases.extend(losing_aliases)
|
||||
|
||||
# 4. 标准化并去重(使用标准化后的字符串进行去重)
|
||||
try:
|
||||
from app.core.memory.utils.alias_utils import normalize_aliases
|
||||
canonical.aliases = normalize_aliases(canonical_name, all_aliases)
|
||||
except Exception:
|
||||
# 如果导入失败,使用增强的去重逻辑
|
||||
# 使用标准化后的字符串作为key进行去重
|
||||
seen_normalized = set()
|
||||
unique_aliases = []
|
||||
|
||||
for alias in all_aliases:
|
||||
if not alias:
|
||||
continue
|
||||
|
||||
alias_stripped = str(alias).strip()
|
||||
if not alias_stripped or alias_stripped == canonical_name:
|
||||
continue
|
||||
|
||||
# 标准化:转小写用于去重判断
|
||||
alias_normalized = alias_stripped.lower()
|
||||
|
||||
if alias_normalized not in seen_normalized:
|
||||
seen_normalized.add(alias_normalized)
|
||||
unique_aliases.append(alias_stripped)
|
||||
|
||||
# 排序并赋值
|
||||
canonical.aliases = sorted(unique_aliases)
|
||||
|
||||
# ========== 主循环:遍历所有实体对进行模糊匹配 ==========
|
||||
@@ -661,6 +824,11 @@ def fuzzy_match(
|
||||
# 条件A(快速通道):alias_match_merge = True
|
||||
# 条件B(标准通道):s_name ≥ tn AND s_type ≥ type_threshold AND overall ≥ tover
|
||||
if alias_match_merge or (s_name >= tn and s_type >= type_threshold and overall >= tover):
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
# ========== 第六步:执行实体合并 ==========
|
||||
|
||||
# 6.1 合并别名
|
||||
@@ -770,6 +938,12 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
b = entity_by_id.get(losing_id)
|
||||
if not a or not b: # 若不存在 a 或 b,可能已在精确或模糊阶段合并,在之前阶段合并之后,不会再处理但是处于审计的目的会记录
|
||||
continue
|
||||
# 保护:禁止跨角色合并(用户实体和AI助手实体不能互相合并)
|
||||
if _would_merge_cross_role(a, b):
|
||||
llm_records.append(
|
||||
f"[LLM阻断] 跨角色合并被阻止: {a.id} ({a.name}) 与 {b.id} ({b.name})"
|
||||
)
|
||||
continue
|
||||
_merge_attribute(a, b)
|
||||
# ID 重定向
|
||||
try:
|
||||
@@ -891,6 +1065,9 @@ async def deduplicate_entities_and_edges(
|
||||
返回:去重后的实体、语句→实体边、实体↔实体边。
|
||||
"""
|
||||
local_llm_records: List[str] = [] # 作为“审计日志”的本地收集器 初始化,保留为了之后对于LLM决策追溯
|
||||
# 0) 标准化用户和AI助手实体名称(确保多轮对话中的变体名称统一)
|
||||
_normalize_special_entity_names(entity_nodes)
|
||||
|
||||
# 1) 精确匹配
|
||||
deduped_entities, id_redirect, exact_merge_map = accurate_match(entity_nodes)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.variate_config import ExtractionPipelineConfig
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.deduped_and_disamb import (
|
||||
deduplicate_entities_and_edges,
|
||||
clean_cross_role_aliases,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.deduplication.second_layer_dedup import (
|
||||
second_layer_dedup_and_merge_with_neo4j,
|
||||
@@ -25,17 +26,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client = None,
|
||||
dialogue_nodes: List[DialogueNode],
|
||||
chunk_nodes: List[ChunkNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_chunk_edges: List[StatementChunkEdge],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
entity_entity_edges: List[EntityEntityEdge],
|
||||
dialog_data_list: List[DialogData],
|
||||
pipeline_config: ExtractionPipelineConfig,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
llm_client=None,
|
||||
) -> Tuple[
|
||||
List[DialogueNode],
|
||||
List[ChunkNode],
|
||||
@@ -44,7 +45,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
List[StatementChunkEdge],
|
||||
List[StatementEntityEdge],
|
||||
List[EntityEntityEdge],
|
||||
dict, # 新增:返回去重详情
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
执行两层实体去重与融合:
|
||||
@@ -100,6 +101,10 @@ async def dedup_layers_and_merge_and_return(
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
# 第二层去重后,清洗用户/AI助手之间的别名交叉污染
|
||||
# 第二层从 Neo4j 合并了旧实体,可能带入历史脏数据
|
||||
clean_cross_role_aliases(fused_entity_nodes)
|
||||
|
||||
return (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,11 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
||||
return await self.embedder_client.response(texts)
|
||||
|
||||
# 分批并行处理
|
||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||
|
||||
# 并行发送所有批次
|
||||
batch_results = await asyncio.gather(*[
|
||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
||||
for batch_result in batch_results:
|
||||
embeddings.extend(batch_result)
|
||||
|
||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||
return embeddings
|
||||
|
||||
async def generate_statement_embeddings(
|
||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的陈述句嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成陈述句嵌入向量 ===")
|
||||
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||
|
||||
# 收集所有陈述句
|
||||
all_statements = []
|
||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||
|
||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||
return stmt_embedding_maps
|
||||
|
||||
async def generate_chunk_embeddings(
|
||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
每个对话的分块嵌入向量映射列表
|
||||
"""
|
||||
print("\n=== 生成分块嵌入向量 ===")
|
||||
logger.debug("=== 生成分块嵌入向量 ===")
|
||||
|
||||
# 收集所有分块
|
||||
all_chunks = []
|
||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||
|
||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||
return chunk_embedding_maps
|
||||
|
||||
async def generate_dialog_embeddings(
|
||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||
"""
|
||||
print("\n=== 生成所有嵌入向量 ===")
|
||||
logger.debug("=== 生成所有嵌入向量 ===")
|
||||
|
||||
# 并发生成陈述句和分块嵌入向量
|
||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
||||
# 对话嵌入向量(当前跳过)
|
||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||
|
||||
print(
|
||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
||||
)
|
||||
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||
|
||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||
|
||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
||||
Returns:
|
||||
更新后的三元组映射列表(实体包含嵌入向量)
|
||||
"""
|
||||
print("\n=== 生成实体嵌入向量 ===")
|
||||
logger.debug("=== 生成实体嵌入向量 ===")
|
||||
|
||||
entity_texts: List[str] = []
|
||||
entity_refs: List[Any] = []
|
||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
||||
entity_refs.append(ent)
|
||||
|
||||
if not entity_texts:
|
||||
print("没有找到需要生成嵌入向量的实体")
|
||||
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||
return triplet_maps
|
||||
|
||||
# 批量生成嵌入向量
|
||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
||||
|
||||
# 打印前几个嵌入向量的维度
|
||||
for i in range(min(5, len(embeddings))):
|
||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||
|
||||
# 将嵌入向量赋值给实体
|
||||
for ent, emb in zip(entity_refs, embeddings):
|
||||
setattr(ent, "name_embedding", emb)
|
||||
|
||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||
return triplet_maps
|
||||
|
||||
|
||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
||||
Returns:
|
||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||
"""
|
||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||
|
||||
generator = EmbeddingGenerator(embedding_id)
|
||||
|
||||
|
||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Metadata extractor module.
|
||||
|
||||
Collects user-related statements from post-dedup graph data and
|
||||
extracts user metadata via an independent LLM call.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.models.graph_models import (
|
||||
ExtractedEntityNode,
|
||||
StatementEntityEdge,
|
||||
StatementNode,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Reuse the same user-entity detection logic from dedup module
|
||||
_USER_NAMES = {"用户", "我", "user", "i"}
|
||||
_CANONICAL_USER_TYPE = "用户"
|
||||
|
||||
|
||||
def _is_user_entity(ent: ExtractedEntityNode) -> bool:
|
||||
"""判断实体是否为用户实体"""
|
||||
name = (getattr(ent, "name", "") or "").strip().lower()
|
||||
etype = (getattr(ent, "entity_type", "") or "").strip()
|
||||
return name in _USER_NAMES or etype == _CANONICAL_USER_TYPE
|
||||
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Extracts user metadata from post-dedup graph data via independent LLM call."""
|
||||
|
||||
def __init__(self, llm_client, language: Optional[str] = None):
|
||||
self.llm_client = llm_client
|
||||
self.language = language
|
||||
|
||||
@staticmethod
|
||||
def detect_language(statements: List[str]) -> str:
|
||||
"""根据 statement 文本内容检测语言。
|
||||
如果文本中包含中文字符则返回 "zh",否则返回 "en"。
|
||||
"""
|
||||
import re
|
||||
|
||||
combined = " ".join(statements)
|
||||
if re.search(r"[\u4e00-\u9fff]", combined):
|
||||
return "zh"
|
||||
return "en"
|
||||
|
||||
def collect_user_related_statements(
|
||||
self,
|
||||
entity_nodes: List[ExtractedEntityNode],
|
||||
statement_nodes: List[StatementNode],
|
||||
statement_entity_edges: List[StatementEntityEdge],
|
||||
) -> List[str]:
|
||||
"""
|
||||
从去重后的数据中筛选与用户直接相关且由用户发言的 statement 文本。
|
||||
|
||||
筛选逻辑:
|
||||
1. 用户实体 → StatementEntityEdge → statement(直接关联)
|
||||
2. 只保留 speaker="user" 的 statement(过滤 assistant 回复的噪声)
|
||||
|
||||
Returns:
|
||||
用户发言的 statement 文本列表
|
||||
"""
|
||||
# Find user entity IDs
|
||||
user_entity_ids = set()
|
||||
for ent in entity_nodes:
|
||||
if _is_user_entity(ent):
|
||||
user_entity_ids.add(ent.id)
|
||||
|
||||
if not user_entity_ids:
|
||||
logger.debug("未找到用户实体节点,跳过 statement 收集")
|
||||
return []
|
||||
|
||||
# 用户实体 → StatementEntityEdge → statement
|
||||
target_stmt_ids = set()
|
||||
for edge in statement_entity_edges:
|
||||
if edge.target in user_entity_ids:
|
||||
target_stmt_ids.add(edge.source)
|
||||
|
||||
# Collect: only speaker="user" statements, preserving order
|
||||
result = []
|
||||
seen = set()
|
||||
total_associated = 0
|
||||
skipped_non_user = 0
|
||||
for stmt_node in statement_nodes:
|
||||
if stmt_node.id in target_stmt_ids and stmt_node.id not in seen:
|
||||
total_associated += 1
|
||||
speaker = getattr(stmt_node, "speaker", None) or "unknown"
|
||||
if speaker == "user":
|
||||
text = (stmt_node.statement or "").strip()
|
||||
if text:
|
||||
result.append(text)
|
||||
else:
|
||||
skipped_non_user += 1
|
||||
seen.add(stmt_node.id)
|
||||
|
||||
logger.info(
|
||||
f"收集到 {len(result)} 条用户发言 statement "
|
||||
f"(直接关联: {total_associated}, speaker=user: {len(result)}, "
|
||||
f"跳过非user: {skipped_non_user})"
|
||||
)
|
||||
if result:
|
||||
for i, text in enumerate(result):
|
||||
logger.info(f" [user statement {i + 1}] {text}")
|
||||
if total_associated > 0 and len(result) == 0:
|
||||
logger.warning(
|
||||
f"有 {total_associated} 条直接关联 statement 但全部被 speaker 过滤,"
|
||||
f"可能本次写入不包含 user 消息"
|
||||
)
|
||||
return result
|
||||
|
||||
async def extract_metadata(
|
||||
self,
|
||||
statements: List[str],
|
||||
existing_metadata: Optional[dict] = None,
|
||||
existing_aliases: Optional[List[str]] = None,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
对筛选后的 statement 列表调用 LLM 提取元数据和用户别名。
|
||||
|
||||
Args:
|
||||
statements: 用户发言的 statement 文本列表
|
||||
existing_metadata: 数据库已有的元数据(可选)
|
||||
existing_aliases: 数据库已有的用户别名列表(可选)
|
||||
|
||||
Returns:
|
||||
(UserMetadata, List[str], List[str]) tuple: (metadata, aliases_to_add, aliases_to_remove) on success, None on failure
|
||||
"""
|
||||
if not statements:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env
|
||||
|
||||
if self.language:
|
||||
detected_language = self.language
|
||||
logger.info(f"元数据提取使用显式指定语言: {detected_language}")
|
||||
else:
|
||||
detected_language = self.detect_language(statements)
|
||||
logger.info(f"元数据提取语言自动检测结果: {detected_language}")
|
||||
|
||||
template = prompt_env.get_template("extract_user_metadata.jinja2")
|
||||
prompt = template.render(
|
||||
statements=statements,
|
||||
language=detected_language,
|
||||
existing_metadata=existing_metadata,
|
||||
existing_aliases=existing_aliases,
|
||||
json_schema="",
|
||||
)
|
||||
|
||||
from app.core.memory.models.metadata_models import (
|
||||
MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=MetadataExtractionResponse,
|
||||
)
|
||||
|
||||
if response:
|
||||
metadata = response.user_metadata if response.user_metadata else None
|
||||
to_add = response.aliases_to_add if response.aliases_to_add else []
|
||||
to_remove = (
|
||||
response.aliases_to_remove if response.aliases_to_remove else []
|
||||
)
|
||||
return metadata, to_add, to_remove
|
||||
|
||||
logger.warning("LLM 返回的响应为空")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"元数据提取 LLM 调用失败: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@@ -82,6 +81,7 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
@@ -94,7 +94,8 @@ class StatementExtractor:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
chunk_content = chunk.content
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
return []
|
||||
@@ -149,8 +150,6 @@ class StatementExtractor:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +13,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -61,11 +60,13 @@ class TripletExtractor:
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
speaker=getattr(statement, 'speaker', None),
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -42,22 +42,21 @@ class AccessHistoryManager:
|
||||
- access_count: 访问次数
|
||||
|
||||
特性:
|
||||
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
|
||||
- 并发安全:使用乐观锁机制防止并发冲突
|
||||
- 原子性更新:使用 APOC 原子操作确保并发安全
|
||||
- 批次内合并:同一批次中对同一节点的多次访问合并为一次更新
|
||||
- 一致性保证:提供一致性检查和自动修复功能
|
||||
- 智能修剪:自动修剪过长的访问历史
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
max_retries: int = 3
|
||||
max_retries: int = 5
|
||||
):
|
||||
"""
|
||||
初始化访问历史管理器
|
||||
@@ -65,47 +64,35 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数(默认3次)
|
||||
max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.max_retries = max_retries
|
||||
|
||||
|
||||
async def record_access(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
current_time: Optional[datetime] = None,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
记录节点访问并原子性更新所有相关字段
|
||||
|
||||
这是核心方法,实现了:
|
||||
1. 首次访问:初始化access_history,计算初始激活值
|
||||
2. 后续访问:追加访问历史,重新计算激活值
|
||||
3. 历史修剪:当历史过长时自动修剪
|
||||
4. 原子性:所有字段在单个事务中更新
|
||||
5. 并发安全:使用乐观锁重试机制
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
access_times: 本次访问次数(默认1,批量合并时可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据,包含:
|
||||
- id: 节点ID
|
||||
- activation_value: 更新后的激活值
|
||||
- access_history: 更新后的访问历史
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
- importance_score: 重要性分数
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点不存在或节点标签无效
|
||||
RuntimeError: 如果重试次数耗尽仍然失败
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
@@ -119,55 +106,48 @@ class AccessHistoryManager:
|
||||
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
|
||||
)
|
||||
|
||||
# 使用乐观锁重试机制处理并发冲突
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤3:原子性更新节点(使用事务)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso,
|
||||
access_times=access_times
|
||||
)
|
||||
|
||||
# 步骤3:使用 APOC 原子操作更新节点(无需重试)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access: {str(e)}"
|
||||
) from e
|
||||
|
||||
async def record_batch_access(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
@@ -178,11 +158,10 @@ class AccessHistoryManager:
|
||||
"""
|
||||
批量记录多个节点的访问
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史。
|
||||
每个节点独立更新,失败的节点不影响其他节点。
|
||||
对同一个节点的多次访问会先在内存中合并,只发起一次更新。
|
||||
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_ids: 节点ID列表(可包含重复ID)
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
@@ -196,25 +175,38 @@ class AccessHistoryManager:
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
# 合并同一节点的访问次数,避免对同一节点并发写入
|
||||
access_count_map: Dict[str, int] = {}
|
||||
for node_id in node_ids:
|
||||
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
|
||||
|
||||
merged_count = len(node_ids) - len(access_count_map)
|
||||
if merged_count > 0:
|
||||
logger.info(
|
||||
f"批量访问合并: 原始={len(node_ids)}, "
|
||||
f"去重后={len(access_count_map)}, 合并={merged_count}"
|
||||
)
|
||||
|
||||
# 对去重后的节点并行发起更新
|
||||
tasks = []
|
||||
for node_id, access_times in access_count_map.items():
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
current_time=current_time,
|
||||
access_times=access_times
|
||||
)
|
||||
tasks.append(task)
|
||||
tasks.append((node_id, task))
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
task_results = await asyncio.gather(
|
||||
*[t for _, t in tasks], return_exceptions=True
|
||||
)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
for (node_id, _), result in zip(tasks, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
@@ -225,12 +217,12 @@ class AccessHistoryManager:
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def check_consistency(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -239,22 +231,6 @@ class AccessHistoryManager:
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
|
||||
验证以下一致性规则:
|
||||
1. access_history[-1] == last_access_time
|
||||
2. len(access_history) == access_count
|
||||
3. 如果有访问历史,必须有激活值
|
||||
4. 激活值必须在有效范围内 [offset, 1.0]
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
@@ -266,7 +242,6 @@ class AccessHistoryManager:
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
|
||||
# 检查1:access_history[-1] == last_access_time
|
||||
if access_history and last_access_time:
|
||||
if access_history[-1] != last_access_time:
|
||||
return (
|
||||
@@ -275,7 +250,6 @@ class AccessHistoryManager:
|
||||
f"last_access_time={last_access_time}"
|
||||
)
|
||||
|
||||
# 检查2:len(access_history) == access_count
|
||||
if len(access_history) != access_count:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
|
||||
@@ -283,14 +257,12 @@ class AccessHistoryManager:
|
||||
f"access_count={access_count}"
|
||||
)
|
||||
|
||||
# 检查3:有访问历史必须有激活值
|
||||
if access_history and activation_value is None:
|
||||
return (
|
||||
ConsistencyCheckResult.MISSING_ACTIVATION,
|
||||
"Node has access_history but activation_value is None"
|
||||
)
|
||||
|
||||
# 检查4:激活值范围
|
||||
if activation_value is not None:
|
||||
offset = self.actr_calculator.offset
|
||||
if not (offset <= activation_value <= 1.0):
|
||||
@@ -301,30 +273,14 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检查多个节点的一致性
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一致性检查报告,包含:
|
||||
- total_checked: 检查的节点总数
|
||||
- consistent_count: 一致的节点数
|
||||
- inconsistent_count: 不一致的节点数
|
||||
- inconsistencies: 不一致节点的详细信息列表
|
||||
- consistency_rate: 一致性率(0-1)
|
||||
"""
|
||||
# 查询所有相关节点
|
||||
"""批量检查多个节点的一致性"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
@@ -343,7 +299,6 @@ class AccessHistoryManager:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
|
||||
# 检查每个节点
|
||||
inconsistencies = []
|
||||
consistent_count = 0
|
||||
|
||||
@@ -382,32 +337,15 @@ class AccessHistoryManager:
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
|
||||
async def repair_inconsistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
|
||||
修复策略:
|
||||
1. 如果access_history[-1] != last_access_time:使用access_history[-1]
|
||||
2. 如果len(access_history) != access_count:使用len(access_history)
|
||||
3. 如果有历史但无激活值:重新计算激活值
|
||||
4. 如果激活值超出范围:重新计算激活值
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
"""
|
||||
"""自动修复节点的数据不一致问题"""
|
||||
try:
|
||||
# 检查一致性
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
@@ -418,7 +356,6 @@ class AccessHistoryManager:
|
||||
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
@@ -427,17 +364,13 @@ class AccessHistoryManager:
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
repair_data = {}
|
||||
|
||||
# 修复last_access_time
|
||||
if access_history:
|
||||
repair_data['last_access_time'] = access_history[-1]
|
||||
|
||||
# 修复access_count
|
||||
repair_data['access_count'] = len(access_history)
|
||||
|
||||
# 修复activation_value
|
||||
if access_history:
|
||||
current_time = datetime.now()
|
||||
last_access_dt = datetime.fromisoformat(access_history[-1])
|
||||
@@ -453,7 +386,6 @@ class AccessHistoryManager:
|
||||
)
|
||||
repair_data['activation_value'] = activation_value
|
||||
|
||||
# 执行修复
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -484,26 +416,16 @@ class AccessHistoryManager:
|
||||
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
|
||||
async def _fetch_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
"""
|
||||
"""获取节点数据"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
@@ -527,12 +449,13 @@ class AccessHistoryManager:
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
|
||||
async def _calculate_update(
|
||||
self,
|
||||
node_data: Dict[str, Any],
|
||||
current_time: datetime,
|
||||
current_time_iso: str
|
||||
current_time_iso: str,
|
||||
access_times: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算更新数据
|
||||
@@ -541,45 +464,40 @@ class AccessHistoryManager:
|
||||
node_data: 当前节点数据
|
||||
current_time: 当前时间(datetime对象)
|
||||
current_time_iso: 当前时间(ISO格式字符串)
|
||||
access_times: 本次访问次数(合并后可能大于1)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
Dict[str, Any]: 更新数据
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
# 本次新增的时间戳
|
||||
new_timestamps = [current_time_iso] * access_times
|
||||
|
||||
# 修剪访问历史(如果过长)
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in new_access_history
|
||||
]
|
||||
# 仅用本次新增的访问记录计算激活值
|
||||
new_history_dt = [current_time] * access_times
|
||||
trimmed_history_dt = self.actr_calculator.trim_access_history(
|
||||
access_history=access_history_dt,
|
||||
access_history=new_history_dt,
|
||||
current_time=current_time
|
||||
)
|
||||
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
|
||||
|
||||
# 计算新的激活值
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=trimmed_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=current_time, # 最后访问时间就是当前时间
|
||||
last_access_time=current_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
# 返回所有需要更新的字段
|
||||
return {
|
||||
'activation_value': activation_value,
|
||||
'access_history': trimmed_history,
|
||||
'new_timestamps': new_timestamps,
|
||||
'access_count_delta': access_times,
|
||||
'access_count': len(trimmed_history_dt),
|
||||
'last_access_time': current_time_iso,
|
||||
'access_count': len(trimmed_history)
|
||||
}
|
||||
|
||||
|
||||
async def _atomic_update(
|
||||
self,
|
||||
node_id: str,
|
||||
@@ -588,10 +506,10 @@ class AccessHistoryManager:
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
原子性更新节点(使用 APOC 原子操作)
|
||||
|
||||
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
|
||||
实现乐观锁机制防止并发冲突。
|
||||
使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全,
|
||||
无需 version 字段和乐观锁,数据库层面保证原子性。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
@@ -603,126 +521,68 @@ class AccessHistoryManager:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
RuntimeError: 如果更新失败
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score
|
||||
"""
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder',
|
||||
'Community': 'n.summary as summary'
|
||||
}
|
||||
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
where_clause = ""
|
||||
if end_user_id:
|
||||
where_clause = " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE true{where_clause}
|
||||
CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count
|
||||
WITH n
|
||||
CALL (n) {{
|
||||
UNWIND $new_timestamps AS ts
|
||||
CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue
|
||||
RETURN count(*) AS inserted
|
||||
}}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.last_access_time = $last_access_time
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
params = {
|
||||
'node_id': node_id,
|
||||
'access_count_delta': update_data['access_count_delta'],
|
||||
'new_timestamps': update_data['new_timestamps'],
|
||||
'activation_value': update_data['activation_value'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
}
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
try:
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
|
||||
if not current_node:
|
||||
if not results:
|
||||
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
|
||||
|
||||
# 获取当前版本号(如果不存在则为0)
|
||||
current_version = current_node.get('version', 0) or 0
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 根据节点类型构建完整的查询语句
|
||||
content_field_map = {
|
||||
'Statement': 'n.statement as statement',
|
||||
'MemorySummary': 'n.content as content',
|
||||
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||
}
|
||||
|
||||
# 显式检查节点类型,不支持的类型抛出错误
|
||||
if node_label not in content_field_map:
|
||||
raise ValueError(
|
||||
f"Unsupported node_label: {node_label}. "
|
||||
f"Supported labels are: {list(content_field_map.keys())}"
|
||||
)
|
||||
|
||||
content_field = content_field_map[node_label]
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
where_conditions.append("n.version = $current_version")
|
||||
else:
|
||||
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||
|
||||
# 构建完整的更新查询
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
WHERE {where_clause}
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version,
|
||||
{content_field}
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
'node_id': node_id,
|
||||
'current_version': current_version,
|
||||
'new_version': new_version,
|
||||
'activation_value': update_data['activation_value'],
|
||||
'access_history': update_data['access_history'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
|
||||
if not updated_node:
|
||||
raise RuntimeError(
|
||||
f"Version conflict detected for {node_label}[{node_id}]. "
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
# 转换为字典并移除占位符字段
|
||||
result_dict = dict(updated_node)
|
||||
result_dict = dict(results[0])
|
||||
result_dict.pop('content_placeholder', None)
|
||||
|
||||
return result_dict
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
result = await self.connector.execute_write_transaction(
|
||||
update_transaction,
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
自我反思引擎实现
|
||||
Self-Reflection Engine Implementation
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
1. 基于时间的反思 - 根据时间周期触发反思
|
||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
||||
3. 综合反思 - 整合多种反思策略
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
This module implements the self-reflection functionality of the memory system, including:
|
||||
1. Time-based reflection - Triggers reflection based on time cycles
|
||||
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||
4. Reflection result application - Updates memory database
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
# Configure logging
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(
|
||||
@@ -49,35 +49,62 @@ else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
"""Translation response model for language conversion"""
|
||||
data: str
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
"""
|
||||
Reflection range enumeration
|
||||
|
||||
Defines the scope of data to be included in reflection operations.
|
||||
"""
|
||||
PARTIAL = "partial" # Reflect from retrieval results
|
||||
ALL = "all" # Reflect from entire database
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
"""
|
||||
Reflection baseline enumeration
|
||||
|
||||
Defines the strategy or approach used for reflection operations.
|
||||
"""
|
||||
TIME = "TIME" # Time-based reflection
|
||||
FACT = "FACT" # Fact-based reflection
|
||||
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
|
||||
|
||||
|
||||
class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
"""
|
||||
Reflection engine configuration
|
||||
|
||||
Defines all configuration parameters for the reflection engine including
|
||||
operation modes, model settings, and evaluation criteria.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether reflection engine is enabled
|
||||
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||
model_id: LLM model identifier for reflection operations
|
||||
end_user_id: User identifier for scoped operations
|
||||
output_example: Example output format for guidance
|
||||
memory_verify: Enable memory verification checks
|
||||
quality_assessment: Enable quality assessment evaluation
|
||||
violation_handling_strategy: Strategy for handling violations
|
||||
language_type: Language type for output ("zh" or "en")
|
||||
"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
iteration_period: str = "3" # Reflection cycle period
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
model_id: Optional[str] = None # Model ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
output_example: Optional[str] = None # Output example
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
# Evaluation related fields
|
||||
memory_verify: bool = True # Memory verification
|
||||
quality_assessment: bool = True # Quality assessment
|
||||
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
||||
|
||||
|
||||
class ReflectionResult(BaseModel):
|
||||
"""反思结果"""
|
||||
"""
|
||||
Reflection operation result
|
||||
|
||||
Contains comprehensive information about the outcome of a reflection operation
|
||||
including success status, metrics, and execution details.
|
||||
|
||||
Attributes:
|
||||
success: Whether the reflection operation succeeded
|
||||
message: Descriptive message about the operation result
|
||||
conflicts_found: Number of conflicts detected during reflection
|
||||
conflicts_resolved: Number of conflicts successfully resolved
|
||||
memories_updated: Number of memory entries updated in database
|
||||
execution_time: Total time taken for the reflection operation
|
||||
details: Additional details about the operation (optional)
|
||||
"""
|
||||
success: bool
|
||||
message: str
|
||||
conflicts_found: int = 0
|
||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
||||
|
||||
class ReflectionEngine:
|
||||
"""
|
||||
自我反思引擎
|
||||
|
||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
||||
Self-Reflection Engine
|
||||
|
||||
Responsible for executing memory system self-reflection operations including
|
||||
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||
reflection strategies and provides comprehensive result tracking.
|
||||
|
||||
The engine can operate in different modes:
|
||||
- Time-based: Reflects on memories within specific time periods
|
||||
- Fact-based: Detects and resolves factual conflicts in memories
|
||||
- Hybrid: Combines multiple reflection strategies
|
||||
|
||||
Attributes:
|
||||
config: Reflection engine configuration
|
||||
neo4j_connector: Neo4j database connector
|
||||
llm_client: Language model client for analysis
|
||||
Various function handlers for data processing and prompt rendering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
Initialize reflection engine
|
||||
|
||||
Sets up the reflection engine with configuration and optional dependencies.
|
||||
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||
|
||||
Args:
|
||||
config: 反思引擎配置
|
||||
neo4j_connector: Neo4j 连接器(可选)
|
||||
llm_client: LLM 客户端(可选)
|
||||
get_data_func: 获取数据的函数(可选)
|
||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
||||
conflict_schema: 冲突结果 Schema(可选)
|
||||
reflexion_schema: 反思结果 Schema(可选)
|
||||
update_query: 更新查询语句(可选)
|
||||
config: Reflection engine configuration object
|
||||
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||
llm_client: LLM client instance (optional, will be created if not provided)
|
||||
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||
conflict_schema: Schema for conflict result validation (optional)
|
||||
reflexion_schema: Schema for reflection result validation (optional)
|
||||
update_query: Query string for database updates (optional)
|
||||
"""
|
||||
self.config = config
|
||||
self.neo4j_connector = neo4j_connector
|
||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
# Lazy import to avoid circular dependencies
|
||||
self._lazy_init_done = False
|
||||
|
||||
def _lazy_init(self):
|
||||
"""延迟初始化,避免循环导入"""
|
||||
"""
|
||||
Lazy initialization to avoid circular imports
|
||||
|
||||
Initializes dependencies only when needed, preventing circular import issues
|
||||
and optimizing startup performance. Sets up default implementations for
|
||||
any components not provided during construction.
|
||||
"""
|
||||
if self._lazy_init_done:
|
||||
return
|
||||
|
||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# If llm_client is a string (model_id), use it to initialize the client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||
"max_tokens": 600, # Limit maximum token count
|
||||
"top_p": 0.8, # Optimize sampling parameters
|
||||
"stream": False, # Ensure non-streaming output for fastest response
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
||||
if self.get_data_func is None:
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
# Import get_data_statement function
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
||||
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
Execute complete reflection workflow
|
||||
|
||||
Performs the full reflection process including data retrieval, conflict detection,
|
||||
conflict resolution, and memory updates. This is the main entry point for
|
||||
reflection operations.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host identifier for scoping reflection operations
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive result of the reflection operation including
|
||||
success status, conflict metrics, and execution time
|
||||
"""
|
||||
# 延迟初始化
|
||||
# Lazy initialization
|
||||
self._lazy_init()
|
||||
|
||||
if not self.config.enabled:
|
||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
# 1. Get reflection data
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
# 2. Detect conflicts (fact-based reflection)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
# 4. Apply reflection results (update memory database)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
"""
|
||||
Translate Chinese text to English
|
||||
|
||||
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||
Provides consistent translation format for reflection results.
|
||||
|
||||
Args:
|
||||
text: Chinese text to be translated
|
||||
|
||||
Returns:
|
||||
str: Translated English text
|
||||
"""
|
||||
# Translate Chinese to English
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
"""
|
||||
Extract and translate reflection data to English
|
||||
|
||||
Processes reflection data structure and translates all Chinese content to English.
|
||||
Handles nested data structures including memory verifications, quality assessments,
|
||||
and reflection data while preserving the original structure.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing reflection data with Chinese content
|
||||
|
||||
Returns:
|
||||
dict: Translated data structure with English content
|
||||
"""
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
"""
|
||||
Execute reflection workflow with comprehensive data processing
|
||||
|
||||
Performs a complete reflection operation including conflict detection, resolution,
|
||||
and result formatting. Supports both Chinese and English output based on
|
||||
configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Comprehensive reflection results including source data, memory verifications,
|
||||
quality assessments, and reflection data. Results are translated to English
|
||||
if language_type is set to 'en'.
|
||||
"""
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
# Traverse data to extract fields
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
# Remove created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
# Traverse data to extract reflexion fields
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
"""
|
||||
Extract source_data and databasets fields from example.json
|
||||
|
||||
Reads reflection example data from the example.json file and extracts
|
||||
the source data and database statements for testing and demonstration purposes.
|
||||
|
||||
Returns:
|
||||
tuple: (source_data, databasets) extracted from the example file
|
||||
Returns empty lists if file reading fails
|
||||
"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
# Read JSON file
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
# Extract fields under memory_verify
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
|
||||
根据配置的反思范围获取需要反思的记忆数据。
|
||||
Get reflection data from database
|
||||
|
||||
Retrieves memory data for reflection based on the configured reflection range.
|
||||
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping data retrieval
|
||||
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||
Returns empty lists if query fails
|
||||
"""
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
||||
Detect conflicts (fact-based reflection)
|
||||
|
||||
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||
Performs comprehensive conflict detection including memory verification and
|
||||
quality assessment based on configuration settings.
|
||||
|
||||
Args:
|
||||
data: 待检测的记忆数据
|
||||
data: Memory data to be analyzed for conflicts
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 冲突记忆列表
|
||||
List[Any]: List of detected conflicts with detailed analysis
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
# Data preprocessing: if data is too small, return no conflicts directly
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# Use converted data
|
||||
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
# Render conflict detection prompt
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema,
|
||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
# 调用 LLM 进行冲突检测
|
||||
# Call LLM for conflict detection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.conflict_schema
|
||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
||||
logging.error("LLM 冲突检测输出解析失败")
|
||||
return []
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return [response.model_dump()]
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
使用 LLM 对检测到的冲突进行反思和解决。
|
||||
Resolve detected conflicts
|
||||
|
||||
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||
|
||||
Args:
|
||||
conflicts: 冲突列表
|
||||
conflicts: List of conflicts to be resolved
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 解决方案列表
|
||||
List[Any]: List of resolution solutions with reflection results
|
||||
"""
|
||||
if not conflicts:
|
||||
return []
|
||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
# Process each conflict in parallel
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解决单个冲突"""
|
||||
"""Resolve a single conflict"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
# 渲染反思提示词
|
||||
# Render reflection prompt
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema,
|
||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
# 调用 LLM 进行反思
|
||||
# Call LLM for reflection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.reflexion_schema
|
||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return response.model_dump()
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
||||
logging.warning(f"解决单个冲突失败: {e}")
|
||||
return None
|
||||
|
||||
# 并发执行所有冲突解决任务
|
||||
# Execute all conflict resolution tasks concurrently
|
||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 过滤掉失败的结果
|
||||
# Filter out failed results
|
||||
solved = [r for r in results if r is not None]
|
||||
|
||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
|
||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
||||
Apply reflection results (update memory database)
|
||||
|
||||
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||
Processes the solved data and applies changes to the memory storage system.
|
||||
|
||||
Args:
|
||||
solved_data: 解决方案列表
|
||||
solved_data: List of resolved conflict solutions with reflection data
|
||||
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
int: Number of successfully updated memory entries
|
||||
"""
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
|
||||
# 基于时间的反思方法
|
||||
# Time-based reflection methods
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
|
||||
根据时间周期触发反思,检查在指定时间段内的记忆。
|
||||
Time-based reflection
|
||||
|
||||
Triggers reflection based on time cycles, checking memories within
|
||||
specified time periods. Uses the configured iteration period if
|
||||
no specific time period is provided.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
time_period: Time period (e.g., "three hours"), uses config value if not provided
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
period = time_period or self.config.iteration_period
|
||||
logging.info(f"执行基于时间的反思,周期: {period}")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 基于事实的反思方法
|
||||
# Fact-based reflection methods
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
|
||||
检测记忆中的事实冲突并解决。
|
||||
Fact-based reflection
|
||||
|
||||
Detects and resolves factual conflicts within memories. Analyzes
|
||||
memory data for inconsistencies and contradictions that need resolution.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
logging.info("执行基于事实的反思")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 综合反思方法
|
||||
# Comprehensive reflection methods
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
|
||||
整合基于时间和基于事实的反思策略。
|
||||
Comprehensive reflection
|
||||
|
||||
Integrates time-based and fact-based reflection strategies based on
|
||||
the configured baseline. Supports hybrid approaches that combine
|
||||
multiple reflection methodologies.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result combining
|
||||
multiple strategies if using hybrid baseline
|
||||
"""
|
||||
logging.info("执行综合反思")
|
||||
|
||||
# 根据配置的基线选择反思策略
|
||||
# Choose reflection strategy based on configured baseline
|
||||
if self.config.baseline == ReflectionBaseline.TIME:
|
||||
return await self.time_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||
return await self.fact_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
||||
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||
time_result = await self.time_based_reflection(host_id)
|
||||
fact_result = await self.fact_based_reflection(host_id)
|
||||
|
||||
# 合并结果
|
||||
# Merge results
|
||||
return ReflectionResult(
|
||||
success=time_result.success and fact_result.success,
|
||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
使用Neo4j的全文索引进行高效的文本匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
@@ -74,7 +74,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
# 调用底层的关键词搜索函数
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
|
||||
@@ -2,9 +2,17 @@ import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
Escape special characters in Lucene queries
|
||||
|
||||
Prevents ParseException when using Neo4j full-text search procedures.
|
||||
Escapes all Lucene reserved special characters and operators.
|
||||
|
||||
Args:
|
||||
query: Original query string
|
||||
|
||||
Returns:
|
||||
str: Escaped query string safe for Lucene search
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
@@ -14,7 +22,9 @@ def escape_lucene_query(query: str) -> str:
|
||||
s = s.replace("\r", " ").replace("\n", " ").strip()
|
||||
|
||||
# Lucene reserved tokens/special characters
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
|
||||
# NOTE: '/' is the regex delimiter in Lucene — must be escaped to prevent
|
||||
# TokenMgrError when the query contains unmatched slashes.
|
||||
specials = ['&&', '||', '\\', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':', '/']
|
||||
# Replace longer tokens first to avoid partial double-escaping
|
||||
for token in sorted(specials, key=len, reverse=True):
|
||||
s = s.replace(token, f"\\{token}")
|
||||
@@ -22,11 +32,21 @@ def escape_lucene_query(query: str) -> str:
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
"""
|
||||
Extract clean plain-text query from various input forms
|
||||
|
||||
Handles the following cases:
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Fallbacks to the raw string when parsing fails
|
||||
- Falls back to raw string when parsing fails
|
||||
- Handles dictionary-type input
|
||||
- Best-effort unescape common escape characters
|
||||
|
||||
Args:
|
||||
query_input: Query input in various forms (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
str: Extracted plain-text query string
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
Validate if the date string is in the format YYYY-MM-DD.
|
||||
Validate if date string conforms to YYYY-MM-DD format
|
||||
|
||||
Args:
|
||||
date_str: Date string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if format is correct, False otherwise
|
||||
"""
|
||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||
return bool(re.match(pattern, date_str))
|
||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
|
||||
"""预处理日期字符串,处理特殊格式"""
|
||||
"""
|
||||
预处理日期字符串,处理特殊格式
|
||||
|
||||
处理以下特殊格式:
|
||||
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||
- 混合分隔符,统一为 "-"
|
||||
|
||||
Args:
|
||||
date_str: 原始日期字符串
|
||||
|
||||
Returns:
|
||||
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||
"""
|
||||
|
||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
|
||||
"""备选解析方案"""
|
||||
"""
|
||||
备选日期解析方案
|
||||
|
||||
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||
支持多种常见的日期格式,包括:
|
||||
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||
- YYYYMMDD, YYMMDD
|
||||
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||
|
||||
Args:
|
||||
date_str: 待解析的日期字符串
|
||||
|
||||
Returns:
|
||||
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||
"""
|
||||
|
||||
# 尝试常见的日期格式[citation:4][citation:5]
|
||||
formats_to_try = [
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering
|
||||
|
||||
# Setup Jinja2 environment
|
||||
@@ -205,6 +205,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
speaker: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
@@ -216,6 +217,7 @@ async def render_triplet_extraction_prompt(
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
speaker: Speaker role ("user" or "assistant") for the current statement
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -223,7 +225,7 @@ async def render_triplet_extraction_prompt(
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_section = None
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
@@ -240,6 +242,7 @@ async def render_triplet_extraction_prompt(
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
speaker=speaker,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, ontology_classes, dialog_text, language
|
||||
输入:pruning_scene, ontology_class_infos, dialog_text, language
|
||||
- ontology_class_infos: List[{class_name: str, class_description: str}]
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -18,20 +19,16 @@
|
||||
#}
|
||||
|
||||
{# ── 确定场景说明 ── #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
@@ -42,8 +39,17 @@
|
||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||
|
||||
场景说明:{{ instruction }}
|
||||
{% if custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
【本场景实体类型定义】
|
||||
以下实体类型定义了本场景中哪些内容是重要的。
|
||||
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}:{{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
重要提示:只要对话中出现与上述任意实体类型相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
@@ -51,13 +57,40 @@
|
||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段)
|
||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||
- 地址信息:地点、地址、位置 → addresses 字段
|
||||
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
|
||||
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段)
|
||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
|
||||
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了")→ preserve_keywords 字段
|
||||
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
|
||||
|
||||
【场景无关内容标记】
|
||||
请从对话中识别出与当前场景({{ pruning_scene }})**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
|
||||
判断标准:
|
||||
- 与场景实体类型完全无关
|
||||
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
|
||||
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
|
||||
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
|
||||
|
||||
**重要:scene_unrelated_snippets 必须认真填写,不能为空数组。**
|
||||
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
|
||||
|
||||
示例(场景=在线教育):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
|
||||
示例(场景=情感陪伴):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日,3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分,这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
|
||||
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
|
||||
|
||||
【可以删除的内容】
|
||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||
@@ -88,7 +121,8 @@
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% else %}
|
||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||
@@ -96,8 +130,17 @@ You are a dialogue content analysis assistant. Please analyze the full dialogue
|
||||
2. Extract all important information fragments that must be preserved.
|
||||
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
[Scene Entity Type Definitions]
|
||||
The following entity types define what content is important in this scene.
|
||||
Content related to ANY of these types must be preserved and extracted into the keywords field:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}: {{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
@@ -105,13 +148,22 @@ Important: If the dialogue contains content related to any of the entity types a
|
||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||
- Time information: dates, time points, durations, expiry dates → times field
|
||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here)
|
||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||
- Address information: locations, addresses, places → addresses field
|
||||
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
|
||||
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here)
|
||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
|
||||
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field
|
||||
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
|
||||
|
||||
[Scene-Unrelated Content Marking]
|
||||
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
|
||||
Criteria:
|
||||
- Completely unrelated to the scene's entity types
|
||||
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
|
||||
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
|
||||
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
|
||||
|
||||
[CAN BE DELETED]
|
||||
The following types of content are low-value and can be removed during pruning:
|
||||
@@ -141,6 +193,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...]
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -43,8 +43,9 @@ Each statement must be labeled as per the criteria mentioned below.
|
||||
|
||||
对话上下文和共指消解:
|
||||
- 将每个陈述句归属于说出它的参与者。
|
||||
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体。
|
||||
- **对于用户的发言:必须使用"用户"作为主语**,禁止将"用户"或"我"替换为用户的真实姓名或别名。例如,用户说"我叫张三"应提取为"用户叫张三",而不是"张三叫张三"。
|
||||
- 对于 AI 助手的发言:使用"助手"或"AI助手"作为主语。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体,但"我"必须解析为"用户"。
|
||||
- 识别并将抽象引用解析为其具体名称(如果提到)。
|
||||
- 将缩写和首字母缩略词扩展为其完整形式。
|
||||
{% else %}
|
||||
@@ -68,8 +69,9 @@ Context Resolution Requirements:
|
||||
|
||||
Conversational Context & Co-reference Resolution:
|
||||
- Attribute every statement to the participant who uttered it.
|
||||
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context.
|
||||
- **For user's statements: always use "用户" (User) as the subject**. Do NOT replace "用户" or "I" with the user's real name or alias. For example, if the user says "I'm John", extract as "用户 is John", not "John is John".
|
||||
- For AI assistant's statements: use "助手" or "AI助手" as the subject.
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context, but "I"/"我" must always resolve to "用户".
|
||||
- Identify and resolve abstract references to their specific names if mentioned.
|
||||
- Expand abbreviations and acronyms to their full form.
|
||||
{% endif %}
|
||||
@@ -139,13 +141,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen 最近一直在尝试水彩画。",
|
||||
"statement": "用户最近一直在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -157,13 +159,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement": "用户认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
|
||||
"statement": "用户真的很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -186,13 +188,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -204,13 +206,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -233,13 +235,13 @@ User: "I think the color combinations could use some improvement, but I really l
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen has been trying watercolor painting recently.",
|
||||
"statement": "用户 has been trying watercolor painting recently.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen painted some flowers.",
|
||||
"statement": "用户 painted some flowers.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -251,13 +253,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement": "用户 thinks the color combinations in her watercolor paintings could use some improvement.",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen really likes roses and lilies.",
|
||||
"statement": "用户 really likes roses and lilies.",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -280,13 +282,13 @@ AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合
|
||||
Example Output: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement": "用户最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement": "用户画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
@@ -298,13 +300,13 @@ Example Output: {
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement": "用户觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement": "用户很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
|
||||
@@ -5,6 +5,15 @@
|
||||
===Task===
|
||||
Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
**⚠️ CRITICAL REQUIREMENTS:**
|
||||
1. **ALIASES ORDER IS CRITICAL**: The FIRST alias in the array will be used as the user's primary display name (other_name). You MUST put the most important/frequently used name FIRST.
|
||||
2. **ALWAYS include aliases field**: Even if empty, you MUST include "aliases": [] in EVERY entity.
|
||||
|
||||
<!-- TODO: v0.2.10 - denied_aliases 功能暂时禁用,将通过 Cypher 查询实现
|
||||
2. **DENIED_ALIASES**: When user explicitly denies a name (e.g., "我不叫X", "I'm not called X"), you MUST put X in denied_aliases field, NOT in aliases.
|
||||
3. **ALWAYS include both fields**: Even if empty, you MUST include "aliases": [] and "denied_aliases": [] in EVERY entity.
|
||||
-->
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。**
|
||||
{% else %}
|
||||
@@ -14,38 +23,43 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
{% if speaker %}
|
||||
**Speaker:** {{ speaker }}
|
||||
{% if speaker == "assistant" %}
|
||||
{% if language == "zh" %}
|
||||
⚠️ 当前陈述句来自 **AI助手的回复**。AI助手在回复中用来称呼用户的名字是**用户的别名**,不是 AI 助手的别名。但只能提取原文中逐字出现的名字,严禁推测或创造原文中不存在的别名变体。
|
||||
{% else %}
|
||||
⚠️ This statement is from the **AI assistant's reply**. Names the AI uses to address the user are **user's aliases**, NOT the AI assistant's aliases. But only extract names that appear VERBATIM in the text — never infer or fabricate alias variants.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
|
||||
**CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.**
|
||||
**CRITICAL: Use ONLY predefined type names below. If no exact match, use CLOSEST type. NEVER invent new types.**
|
||||
|
||||
**If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.**
|
||||
**Type Priority:**
|
||||
1. [场景类型] Scene Types (domain-specific, prefer first)
|
||||
2. [通用类型] General Types (standard ontologies)
|
||||
3. [通用父类] Parent Types (hierarchy context)
|
||||
|
||||
**Type Priority (from highest to lowest):**
|
||||
1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first
|
||||
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
|
||||
3. **[通用父类] Parent Types** - Provide type hierarchy context
|
||||
**Rules:**
|
||||
- Type MUST exactly match predefined names
|
||||
- Do NOT modify, translate, or abbreviate type names
|
||||
- Prefer scene types over general types
|
||||
|
||||
**Type Matching Rules:**
|
||||
- Entity type MUST exactly match one of the predefined type names below
|
||||
- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list
|
||||
- Do NOT modify, translate, abbreviate, or create variations of type names
|
||||
- Prefer scene types (marked [场景类型]) over general types when both could apply
|
||||
- If uncertain, check the type description to find the best match
|
||||
|
||||
**Predefined Ontology Types:**
|
||||
**Predefined Types:**
|
||||
{{ ontology_types }}
|
||||
|
||||
{% if type_hierarchy_hints %}
|
||||
**Type Hierarchy Reference:**
|
||||
The following shows type inheritance relationships (Child → Parent → Grandparent):
|
||||
**Hierarchy:**
|
||||
{% for hint in type_hierarchy_hints %}
|
||||
- {{ hint }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
**ALLOWED Type Names (use EXACTLY one of these, no exceptions):**
|
||||
**ALLOWED Names:**
|
||||
{{ ontology_type_names | join(', ') }}
|
||||
|
||||
{% endif %}
|
||||
@@ -62,66 +76,167 @@ The following shows type inheritance relationships (Child → Parent → Grandpa
|
||||
- **Entity descriptions must be in English**
|
||||
- **Examples must be in English**
|
||||
{% endif %}
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm"
|
||||
- **Technical Terms:** "Neural Network", "Database"
|
||||
* Set to `false` for:
|
||||
- **People:** "John Smith", "Dr. Wang"
|
||||
- **Organizations:** "Microsoft", "Harvard University"
|
||||
- **Locations:** "Beijing", "Central Park"
|
||||
- **Events:** "2024 Conference", "Project Meeting"
|
||||
- **Specific objects:** "iPhone 15", "Building A"
|
||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
||||
* The example should be:
|
||||
- **Specific and concrete**: Use real-world scenarios or applications
|
||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
||||
- **Semantic Memory (is_explicit_memory):**
|
||||
* `true` for: Concepts, Knowledge, Definitions, Theories, Methods (e.g., "Machine Learning", "REST API")
|
||||
* `false` for: People, Organizations, Locations, Events, Specific objects
|
||||
* For `is_explicit_memory=true`, provide concise example (~20 chars{% if language == "zh" %},使用中文{% endif %})
|
||||
|
||||
**🚨🚨🚨 ALIASES & DENIED_ALIASES - MANDATORY FIELDS 🚨🚨🚨**
|
||||
|
||||
**CRITICAL RULES (违反将导致提取失败):**
|
||||
|
||||
1. **EVERY entity MUST have aliases field:**
|
||||
- `"aliases": [...]` - REQUIRED, even if empty `[]`
|
||||
|
||||
2. **ALIASES - 别名提取规则:**
|
||||
{% if language == "zh" %}
|
||||
- **使用中文**
|
||||
- 包含:昵称、全名、简称、别称、网名等
|
||||
- 顺序:**第一个别名将作为用户的主显示名称(other_name),必须把最重要/最常用的名字放在第一位**
|
||||
- 提取顺序:严格按照对话中首次出现的顺序
|
||||
- 示例:
|
||||
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||
- 空值:如果没有别名,使用 `[]`
|
||||
- **🚨🚨🚨 严禁幻觉:只提取对话原文中逐字出现的别名,绝对不能推测、衍生或创造任何未在原文中出现的名字。例如,看到"陈思远"不能自行添加"思远大人""远哥""小远"等变体。如果原文没有这些字,就不能出现在 aliases 中。**
|
||||
- **🚨 归属区分:必须严格区分名称的归属对象。默认情况下,用户提到的名字归属用户实体。只有出现明确的第二人称命名表达(如"叫你""给你取名")时,才将名字归属 AI/助手实体。**
|
||||
- **🚨 说话人视角:当 speaker 为 assistant 时,AI 助手用来称呼用户的名字是用户的别名,必须归入用户实体的 aliases,绝对不能归入 AI 助手实体。但同样只能提取原文中逐字出现的称呼,不能推测。**
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户 aliases=["陈思远"],AI助手 aliases=["远仔"]
|
||||
* "我叫vv" → 用户 aliases=["vv"](没有给AI取名的表达,名字归用户)
|
||||
* [speaker=assistant] "好的,VV" → 用户 aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "我叫陈仔" → AI助手 aliases=["陈仔"](AI 在自我介绍,这是 AI 的别名)
|
||||
* ❌ 错误:将"远仔"放入用户的 aliases("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- **In English**
|
||||
- Include: nicknames, full names, abbreviations, alternative names
|
||||
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||
- Extraction order: Strictly follow the order of first appearance in conversation
|
||||
- Examples:
|
||||
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- Empty: If no aliases, use `[]`
|
||||
- **🚨🚨🚨 NO HALLUCINATION: Only extract aliases that appear VERBATIM in the original text. NEVER infer, derive, or fabricate names not present in the text. For example, seeing "John Smith" does NOT allow adding "Johnny", "Smithy", "Mr. Smith" unless those exact strings appear in the conversation.**
|
||||
- **🚨 Ownership distinction: By default, all names mentioned by the user belong to the user entity. Only assign a name to the AI/assistant entity when an explicit second-person naming expression (e.g., "I'll call you", "your name is") is present.**
|
||||
- **🚨 Speaker perspective: When speaker is "assistant", names the AI uses to address the user are the USER's aliases and MUST go into the user entity's aliases, NEVER into the AI assistant entity's aliases. But only extract names that appear verbatim in the text, never infer.**
|
||||
* "I'm Alex, I'll call you Buddy" → User aliases=["Alex"], AI assistant aliases=["Buddy"]
|
||||
* "I'm vv" → User aliases=["vv"] (no AI-naming expression, name belongs to user)
|
||||
* [speaker=assistant] "Sure thing, VV" → User aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "I'm Jarvis" → AI assistant aliases=["Jarvis"] (AI self-introduction, this is AI's alias)
|
||||
* ❌ Wrong: putting "Buddy" in user's aliases ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
||||
- **Aliases Extraction:**
|
||||
|
||||
|
||||
|
||||
3. **USER ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
* 别名使用中文
|
||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||
- 用户的真实姓名:放入 aliases
|
||||
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||
- 示例:
|
||||
* "我叫李明" → name="用户", aliases=["李明"]
|
||||
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||
{% else %}
|
||||
* Aliases should be in English
|
||||
- User entity name field: use "User" or "I"
|
||||
- User's real name: put in aliases
|
||||
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||
- Examples:
|
||||
* "I'm John" → name="User", aliases=["John"]
|
||||
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||
{% endif %}
|
||||
* Include common alternative names, abbreviations and full names
|
||||
* If no aliases exist, use empty array: []
|
||||
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
|
||||
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
|
||||
|
||||
|
||||
|
||||
4. **AI/ASSISTANT ENTITY SPECIAL HANDLING:**
|
||||
{% if language == "zh" %}
|
||||
- **🚨 默认规则:如果对话中没有出现明确指向 AI/助手的命名表达,则所有名字都归属于用户实体。不要猜测或推断某个名字是给 AI 取的。**
|
||||
- 只有当用户**明确**对 AI/助手进行命名时,才创建 AI/助手实体并将对应名字放入其 aliases
|
||||
- AI/助手实体的 name 字段:使用 "AI助手"
|
||||
- 用户给 AI 取的名字:放入 AI/助手实体的 aliases
|
||||
- **🚨 禁止将用户给 AI 取的名字放入用户实体的 aliases 中**
|
||||
- **必须出现以下明确的命名表达才能判定为给 AI 取名:**「给你取名」「叫你」「称呼你为」「给AI取名」「你的名字是」「以后叫你」「你就叫」「你不叫X了」「你现在叫」等**第二人称(你)或明确指向 AI 的命名句式**
|
||||
- **🚨 "你不叫X了"/"你不叫X,你叫Y" 句式:X 和 Y 都是 AI 的名字(旧名和新名),绝对不是用户的名字。因为句子主语是"你"(AI)。**
|
||||
- **以下情况名字归属用户,不是给 AI 取名:**「我叫」「我的名字是」「叫我」「我是」「大家叫我」「我的英文名是」「我的昵称是」等**第一人称(我)的自我介绍句式**
|
||||
- **🚨 speaker=assistant 时的特殊规则:**
|
||||
* AI 用来称呼用户的名字 → 归入**用户**实体的 aliases(但必须是原文中逐字出现的称呼,不能推测)
|
||||
* AI 自称的名字(如"我叫陈仔""我是你的助手")→ 归入**AI助手**实体的 aliases
|
||||
* 判断依据:AI 说"你叫X"或用 X 称呼用户 → X 是用户别名;AI 说"我叫X"或"我是X" → X 是 AI 别名
|
||||
- 示例:
|
||||
* "我叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我的英文名叫vv" → 用户实体: name="用户", aliases=["vv"](第一人称自我介绍,名字归用户)
|
||||
* "我叫陈思远,我给AI取名为远仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["远仔"]
|
||||
* "叫你小助,我自己叫老王" → 用户实体: name="用户", aliases=["老王"];AI实体: name="AI助手", aliases=["小助"]
|
||||
* "你不叫远仔了,你现在叫陈仔" → AI实体: name="AI助手", aliases=["陈仔"]("远仔"是AI旧名,"陈仔"是AI新名,都归AI。不要把"远仔"或"陈仔"放入用户的aliases)
|
||||
* [speaker=assistant] "好的VV,今天想干点啥?" → 用户实体: name="用户", aliases=["VV"](AI 在称呼用户,原文中出现了"VV")
|
||||
* [speaker=assistant] "你叫陈思远,我叫陈仔" → 用户实体: name="用户", aliases=["陈思远"];AI实体: name="AI助手", aliases=["陈仔"]
|
||||
* ❌ 错误:用户说"我叫vv",却把"vv"放入 AI 助手的 aliases(没有任何给 AI 取名的表达)
|
||||
* ❌ 错误:AI 称呼用户为"VV",却把"VV"放入 AI 助手的 aliases
|
||||
* ❌ 错误:aliases=["陈思远", "远仔"]("远仔"是给AI取的名字,不是用户的名字)
|
||||
* ❌ 错误:原文只有"陈思远",却在 aliases 中添加"思远大人""远哥""小远"等从未出现的变体(这是幻觉)
|
||||
{% else %}
|
||||
- **🚨 Default rule: If there is NO explicit AI/assistant naming expression in the conversation, ALL names belong to the user entity. Do NOT guess or infer that a name is for the AI.**
|
||||
- Only create an AI/assistant entity when the user **explicitly** names the AI/assistant
|
||||
- AI/assistant entity name field: use "AI Assistant"
|
||||
- Names the user gives to the AI: put in the AI/assistant entity's aliases
|
||||
- **🚨 NEVER put names given to the AI into the user entity's aliases**
|
||||
- **An AI-naming expression MUST be present to assign a name to the AI:** "I'll call you", "your name is", "I name you", "let me call you", "you'll be called", "you're not called X anymore", "your new name is", etc. — **second-person ("you") or explicit AI-directed naming patterns**
|
||||
- **🚨 "You're not called X anymore" / "You're not X, you're Y" pattern: BOTH X and Y are AI's names (old and new). They are NOT user's names. The subject is "you" (the AI).**
|
||||
- **These patterns mean the name belongs to the USER, NOT the AI:** "I'm", "my name is", "call me", "I am", "people call me", "my English name is", "my nickname is", etc. — **first-person ("I"/"me") self-introduction patterns**
|
||||
- **🚨 Special rules when speaker=assistant:**
|
||||
* Names the AI uses to address the user → belong to the **user** entity's aliases (but only extract names that appear verbatim in the text, never infer)
|
||||
* Names the AI uses for itself (e.g., "I'm Jarvis", "I am your assistant") → belong to the **AI assistant** entity's aliases
|
||||
* Rule: AI says "you are X" or calls user X → X is user's alias; AI says "I'm X" or "I am X" → X is AI's alias
|
||||
- Examples:
|
||||
* "I'm vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "My English name is vv" → User entity: name="User", aliases=["vv"] (first-person intro, name belongs to user)
|
||||
* "I'm Alex, I'll call you Buddy" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Buddy"]
|
||||
* "Call yourself Jarvis, my name is Tony" → User entity: name="User", aliases=["Tony"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* "You're not called Jarvis anymore, your new name is Friday" → AI entity: name="AI Assistant", aliases=["Friday"] (both "Jarvis" and "Friday" are AI names, NOT user names)
|
||||
* [speaker=assistant] "Sure thing, VV" → User entity: name="User", aliases=["VV"] (AI addressing the user, "VV" appears in text)
|
||||
* [speaker=assistant] "You're Alex, and I'm Jarvis" → User entity: name="User", aliases=["Alex"]; AI entity: name="AI Assistant", aliases=["Jarvis"]
|
||||
* ❌ Wrong: User says "I'm vv" but "vv" is put in AI assistant's aliases (no AI-naming expression exists)
|
||||
* ❌ Wrong: AI calls user "VV" but "VV" is put in AI assistant's aliases
|
||||
* ❌ Wrong: aliases=["Alex", "Buddy"] ("Buddy" is a name for the AI, not the user)
|
||||
* ❌ Wrong: Text only has "John Smith" but aliases include "Johnny", "Smithy" (hallucinated variants)
|
||||
{% endif %}
|
||||
|
||||
5. **ALIASES ORDER:**
|
||||
{% if language == "zh" %}
|
||||
- 顺序优先级:按出现顺序,先出现的在前
|
||||
{% else %}
|
||||
- Order priority: by appearance order, first mentioned comes first
|
||||
{% endif %}
|
||||
|
||||
**EXAMPLES OF CORRECT EXTRACTION:**
|
||||
{% if language == "zh" %}
|
||||
- "我叫张三" → aliases=["张三"] (张三将成为 other_name)
|
||||
- "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name)
|
||||
- "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name)
|
||||
{% else %}
|
||||
- "I'm John" → aliases=["John"] (John will become other_name)
|
||||
- "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||
- "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name)
|
||||
{% endif %}
|
||||
|
||||
- Exclude lengthy quotes, dates, temporal expressions
|
||||
- Numeric values: extract as entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
|
||||
**Triplet Extraction:**
|
||||
- Extract (subject, predicate, object) triplets where:
|
||||
- Subject: main entity performing the action or being described
|
||||
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
|
||||
- Object: entity, value, or concept affected by the predicate
|
||||
- Extract (subject, predicate, object) where subject/object are entities, predicate is relationship
|
||||
{% if language == "zh" %}
|
||||
- subject_name 和 object_name 必须使用中文
|
||||
- subject_name 和 object_name 使用中文
|
||||
{% else %}
|
||||
- subject_name and object_name must be in English (translate if original is in another language)
|
||||
- subject_name and object_name in English
|
||||
{% endif %}
|
||||
- Exclude all temporal expressions from every field
|
||||
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
|
||||
- Do NOT translate predicate tokens
|
||||
- Do NOT include `statement_id` field (assigned automatically)
|
||||
|
||||
**When NOT to extract triplets:**
|
||||
- Non-propositional utterances (emotions, fillers, onomatopoeia)
|
||||
- No clear predicate from the given definitions applies
|
||||
- Standalone noun phrases or checklist items → extract as entities only
|
||||
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
|
||||
|
||||
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
|
||||
- Use ONLY predicates from "Predicate Instructions" (uppercase tokens)
|
||||
- Exclude temporal expressions, do NOT include `statement_id`
|
||||
- **When NOT to extract:** emotions, fillers, no clear predicate, standalone nouns
|
||||
- **If no valid triplet:** Return triplets: []
|
||||
{%- if predicate_instructions -%}
|
||||
|
||||
**Predicate Instructions:**
|
||||
@@ -170,8 +285,19 @@ Output:
|
||||
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (User vs AI alias distinction - English output):** "I'm Alex, and I'll call you Buddy"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "User", "subject_id": 0, "predicate": "NAMED", "object_name": "AI Assistant", "object_id": 1, "value": "Buddy"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "User", "type": "Person", "description": "The user", "example": "", "aliases": ["Alex"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI Assistant", "type": "Person", "description": "The user's AI assistant", "example": "", "aliases": ["Buddy"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
@@ -207,26 +333,85 @@ Output:
|
||||
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (别名 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 5 (别名顺序 - Chinese):** "我叫陈思远。对了,我的网名叫「远山」"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远", "远山"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (用户与AI别名区分 - Chinese):** "我称呼自己为陈思远,我给AI取名为远仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "远仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远"], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["远仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 7 (纯用户自我介绍,无AI命名 - Chinese):** "我叫vv"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["vv"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 8 (给AI改名 - Chinese):** "你不叫远仔了,你现在叫陈仔"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "用户", "subject_id": 0, "predicate": "NAMED", "object_name": "AI助手", "object_id": 1, "value": "陈仔"}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "AI助手", "type": "Person", "description": "用户的AI助手", "example": "", "aliases": ["陈仔"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
{% if ontology_types %}
|
||||
**⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.**
|
||||
**⚠️ REMINDER: Examples use generic types for illustration. You MUST use predefined types from "ALLOWED Names" above.**
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- Use ASCII double quotes ("), escape with \"
|
||||
- No Chinese quotes (""), no line breaks in strings
|
||||
{% if language == "zh" %}
|
||||
- **语言要求:实体名称(name)、描述(description)、示例(example)、subject_name、object_name 必须使用中文**
|
||||
- **语言:name、description、example、subject_name、object_name 使用中文**
|
||||
{% else %}
|
||||
- **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English**
|
||||
- **If the original text is in Chinese, translate all names to English**
|
||||
- **Language: names, descriptions, examples in English (translate if needed)**
|
||||
{% endif %}
|
||||
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||
|
||||
**Output JSON structure:**
|
||||
```json
|
||||
{
|
||||
"triplets": [...],
|
||||
"entities": [...]
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
===Task===
|
||||
Extract user metadata from the following conversation statements spoken by the user.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**"三度原则"判断标准:**
|
||||
- 复用度:该信息是否会被多个功能模块使用?
|
||||
- 约束度:该信息是否会影响系统行为?
|
||||
- 时效性:该信息是长期稳定的还是临时的?仅提取长期稳定信息。
|
||||
|
||||
**提取规则:**
|
||||
- **只提取关于"用户本人"的画像信息**,忽略用户提到的第三方人物(如朋友、同事、家人)的信息
|
||||
- 仅提取文本中明确提到的信息,不要推测
|
||||
- 如果文本中没有可提取的用户画像信息,返回空的 user_metadata 对象
|
||||
- **输出语言必须与输入文本的语言一致**(输入中文则输出中文值,输入英文则输出英文值)
|
||||
|
||||
{% if existing_metadata %}
|
||||
**重要:合并已有元数据**
|
||||
下方提供了数据库中已有的用户元数据。请结合用户最新发言,输出**合并后的完整元数据**:
|
||||
- 如果用户明确否定了已有信息(如"我不再教高中物理了"),在输出中**移除**该信息
|
||||
- 如果用户提到了新信息,**添加**到对应字段中
|
||||
- 如果已有信息未被用户否定,**保留**在输出中
|
||||
- 标量字段(如 role、domain):如果用户提到了新值,用新值替换;否则保留已有值
|
||||
- 最终输出应该是完整的、合并后的元数据,不是增量
|
||||
{% endif %}
|
||||
|
||||
**字段说明:**
|
||||
- profile.role:用户的职业或角色,如 教师、医生、后端工程师
|
||||
- profile.domain:用户所在领域,如 教育、医疗、软件开发
|
||||
- profile.expertise:用户擅长的技能或工具(通用,不限于编程),如 Python、心理咨询、高中物理
|
||||
- profile.interests:用户主动表达兴趣的话题或领域标签
|
||||
- behavioral_hints.learning_stage:学习阶段(初学者/中级/高级)
|
||||
- behavioral_hints.preferred_depth:偏好深度(概览/技术细节/深入探讨)
|
||||
- behavioral_hints.tone_preference:语气偏好(轻松随意/专业简洁/学术严谨)
|
||||
- knowledge_tags:用户涉及的知识领域标签
|
||||
|
||||
**用户别名变更(增量模式):**
|
||||
- **aliases_to_add**:本次新发现的用户别名,包括:
|
||||
* 用户主动自我介绍:如"我叫张三"、"我的名字是XX"、"我的网名是XX"
|
||||
* 他人对用户的称呼:如"同事叫我陈哥"、"大家叫我小张"、"领导叫我老陈"
|
||||
* 只提取原文中逐字出现的名字,严禁推测或创造
|
||||
* 禁止提取:用户给 AI 取的名字、第三方人物自身的名字、"用户"/"我" 等占位词
|
||||
* 如果没有新别名,返回空数组 `[]`
|
||||
- **aliases_to_remove**:用户明确否认的别名,包括:
|
||||
* 用户说"我不叫XX了"、"别叫我XX"、"我改名了,不叫XX" → 将 XX 放入此数组
|
||||
* **严格限制**:只将用户原文中**逐字提到**的被否认名字放入,不要推断关联的其他别名
|
||||
* 例如:用户说"我不叫陈小刀了" → 只移除"陈小刀",不要移除"陈哥"、"老陈"等未被提及的别名
|
||||
* 如果没有要移除的别名,返回空数组 `[]`
|
||||
{% if existing_aliases %}
|
||||
- 已有别名:{{ existing_aliases | tojson }}(仅供参考,不需要在输出中重复)
|
||||
{% endif %}
|
||||
{% else %}
|
||||
**"Three-Degree Principle" criteria:**
|
||||
- Reusability: Will this information be used by multiple functional modules?
|
||||
- Constraint: Will this information affect system behavior?
|
||||
- Timeliness: Is this information long-term stable or temporary? Only extract long-term stable information.
|
||||
|
||||
**Extraction rules:**
|
||||
- **Only extract profile information about the user themselves**, ignore information about third parties (friends, colleagues, family) mentioned by the user
|
||||
- Only extract information explicitly mentioned in the text, do not speculate
|
||||
- If no user profile information can be extracted, return an empty user_metadata object
|
||||
- **Output language must match the input text language**
|
||||
|
||||
{% if existing_metadata %}
|
||||
**Important: Merge with existing metadata**
|
||||
Existing user metadata from the database is provided below. Combine with the user's latest statements to output the **complete merged metadata**:
|
||||
- If the user explicitly negates existing info (e.g. "I no longer teach high school physics"), **remove** it from output
|
||||
- If the user mentions new info, **add** it to the corresponding field
|
||||
- If existing info is not negated by the user, **keep** it in the output
|
||||
- Scalar fields (e.g. role, domain): replace with new value if user mentions one; otherwise keep existing
|
||||
- The final output should be the complete, merged metadata — not an incremental update
|
||||
{% endif %}
|
||||
|
||||
**Field descriptions:**
|
||||
- profile.role: User's occupation or role, e.g. teacher, doctor, software engineer
|
||||
- profile.domain: User's domain, e.g. education, healthcare, software development
|
||||
- profile.expertise: User's skills or tools (general, not limited to programming)
|
||||
- profile.interests: Topics or domain tags the user actively expressed interest in
|
||||
- behavioral_hints.learning_stage: Learning stage (beginner/intermediate/advanced)
|
||||
- behavioral_hints.preferred_depth: Preferred depth (overview/detailed/deep dive)
|
||||
- behavioral_hints.tone_preference: Tone preference (casual/professional/academic)
|
||||
- knowledge_tags: Knowledge domain tags related to the user
|
||||
|
||||
**User alias changes (incremental mode):**
|
||||
- **aliases_to_add**: Newly discovered user aliases from this conversation, including:
|
||||
* User self-introductions: e.g. "I'm John", "My name is XX", "My username is XX"
|
||||
* How others address the user: e.g. "My colleagues call me Johnny", "People call me Mike"
|
||||
* Only extract names that appear VERBATIM in the text — never infer or fabricate
|
||||
* Do NOT extract: names the user gives to the AI, third-party people's own names, placeholder words like "User"/"I"
|
||||
* If no new aliases, return empty array `[]`
|
||||
- **aliases_to_remove**: Aliases the user explicitly denies, including:
|
||||
* User says "Don't call me XX anymore", "I'm not called XX", "I changed my name from XX" → put XX in this array
|
||||
* **Strict rule**: Only include the exact name the user **verbatim mentions** as denied. Do NOT infer or remove related aliases
|
||||
* Example: User says "I'm not called John anymore" → only remove "John", do NOT remove "Johnny", "J" or other related aliases not mentioned
|
||||
* If no aliases to remove, return empty array `[]`
|
||||
{% if existing_aliases %}
|
||||
- Existing aliases: {{ existing_aliases | tojson }} (for reference only, do not repeat in output)
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
===User Statements===
|
||||
{% for stmt in statements %}
|
||||
- {{ stmt }}
|
||||
{% endfor %}
|
||||
|
||||
{% if existing_metadata %}
|
||||
===Existing User Metadata===
|
||||
```json
|
||||
{{ existing_metadata | tojson }}
|
||||
```
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
Return a JSON object with the following structure:
|
||||
```json
|
||||
{
|
||||
"user_metadata": {
|
||||
"profile": {
|
||||
"role": "",
|
||||
"domain": "",
|
||||
"expertise": [],
|
||||
"interests": []
|
||||
},
|
||||
"behavioral_hints": {
|
||||
"learning_stage": "",
|
||||
"preferred_depth": "",
|
||||
"tone_preference": ""
|
||||
},
|
||||
"knowledge_tags": []
|
||||
},
|
||||
"aliases_to_add": [],
|
||||
"aliases_to_remove": []
|
||||
}
|
||||
```
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
||||
from .llm import RedBearLLM
|
||||
from .embedding import RedBearEmbeddings
|
||||
from .rerank import RedBearRerank
|
||||
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||
|
||||
__all__ = [
|
||||
"RedBearModelConfig",
|
||||
@@ -9,5 +10,7 @@ __all__ = [
|
||||
"RedBearEmbeddings",
|
||||
"RedBearRerank",
|
||||
"RedBearModelFactory",
|
||||
"get_provider_llm_class"
|
||||
"get_provider_llm_class",
|
||||
"RedBearImageGenerator",
|
||||
"RedBearVideoGenerator"
|
||||
]
|
||||
@@ -1,23 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
from app.core.models.volcano_chat import VolcanoChatOpenAI
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -29,6 +26,9 @@ class RedBearModelConfig(BaseModel):
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
deep_thinking: bool = False # 是否启用深度思考模式
|
||||
thinking_budget_tokens: Optional[int] = None # 深度思考 token 预算
|
||||
support_thinking: bool = False # 模型是否支持 enable_thinking 参数(capability 含 thinking)
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
@@ -48,7 +48,7 @@ class RedBearModelFactory:
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}, deep_thinking: {config.deep_thinking}")
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
@@ -62,7 +62,7 @@ class RedBearModelFactory:
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -70,8 +70,25 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if is_streaming:
|
||||
params["stream_usage"] = True
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
# 这样可以分别控制连接超时和读取超时
|
||||
import httpx
|
||||
@@ -82,7 +99,7 @@ class RedBearModelFactory:
|
||||
write=60.0, # 写入超时:60秒
|
||||
pool=10.0, # 连接池超时:10秒
|
||||
)
|
||||
return {
|
||||
params: Dict[str, Any] = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
@@ -90,16 +107,50 @@ class RedBearModelFactory:
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 流式模式下启用 stream_usage 以获取 token 统计
|
||||
if config.extra_params.get("streaming"):
|
||||
params["stream_usage"] = True
|
||||
# 深度思考模式
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
if config.support_thinking:
|
||||
if is_streaming and not config.is_omni:
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
# 火山引擎深度思考仅流式调用支持,非流式时不传 thinking 参数
|
||||
thinking_config: Dict[str, Any] = {
|
||||
"type": "enabled" if config.deep_thinking else "disabled"
|
||||
}
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
thinking_config["budget_tokens"] = config.thinking_budget_tokens
|
||||
params["extra_body"] = {"thinking": thinking_config}
|
||||
else:
|
||||
# 始终显式传递 enable_thinking,不支持该参数的模型(如 DeepSeek-R1)会直接忽略
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking and config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
# 只支持: model, dashscope_api_key, max_retries, client
|
||||
return {
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
# 只有支持 thinking 的模型才传 enable_thinking
|
||||
if config.support_thinking:
|
||||
is_streaming = bool(config.extra_params.get("streaming"))
|
||||
model_kwargs: Dict[str, Any] = config.extra_params.get("model_kwargs", {})
|
||||
if is_streaming:
|
||||
model_kwargs["enable_thinking"] = config.deep_thinking
|
||||
if config.deep_thinking:
|
||||
model_kwargs["incremental_output"] = True
|
||||
if config.thinking_budget_tokens:
|
||||
model_kwargs["thinking_budget"] = config.thinking_budget_tokens
|
||||
else:
|
||||
model_kwargs["enable_thinking"] = False
|
||||
params["model_kwargs"] = model_kwargs
|
||||
return params
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
# Bedrock 使用 AWS 凭证
|
||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||
@@ -138,6 +189,13 @@ class RedBearModelFactory:
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
# 深度思考模式:Claude 3.7 Sonnet 等支持思考的模型
|
||||
# 通过 additional_model_request_fields 传递 thinking 块,关闭时不传(Bedrock 无 disabled 选项)
|
||||
if config.deep_thinking:
|
||||
budget = config.thinking_budget_tokens or 10000
|
||||
params["additional_model_request_fields"] = {
|
||||
"thinking": {"type": "enabled", "budget_tokens": budget}
|
||||
}
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -149,10 +207,15 @@ class RedBearModelFactory:
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -163,25 +226,21 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider == ModelProvider.VOLCANO:
|
||||
return VolcanoChatOpenAI
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
@@ -212,6 +271,9 @@ def get_provider_rerank_class(provider: str):
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.document_compressors.dashscope_rerank import DashScopeRerank
|
||||
return DashScopeRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
|
||||
@@ -1,23 +1,217 @@
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
||||
from typing import Any, Dict, List, Union
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
||||
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||
from app.models.models_model import ModelProvider
|
||||
|
||||
|
||||
class RedBearEmbeddings(Embeddings):
|
||||
"""Embedding → 完全符合 LangChain Embeddings"""
|
||||
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
|
||||
|
||||
def __init__(self, config: RedBearModelConfig):
|
||||
self._model = self._create_model(config)
|
||||
self._config = config
|
||||
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
|
||||
|
||||
if self._is_volcano:
|
||||
# 火山引擎使用 Ark SDK
|
||||
self._client = self._create_volcano_client(config)
|
||||
self._model = None
|
||||
else:
|
||||
# 其他 provider 使用 LangChain
|
||||
self._model = self._create_model(config)
|
||||
self._client = None
|
||||
|
||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建模型"""
|
||||
@staticmethod
|
||||
def _create_model(config: RedBearModelConfig) -> Embeddings:
|
||||
"""根据配置创建 LangChain 模型"""
|
||||
embedding_class = get_provider_embedding_class(config.provider)
|
||||
model_params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**model_params)
|
||||
provider = config.provider.lower()
|
||||
# Embedding models only need connection params, never LLM-specific ones
|
||||
# (e.g. enable_thinking, model_kwargs) — build params directly.
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
import httpx
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": httpx.Timeout(timeout=config.timeout, connect=60.0),
|
||||
"max_retries": config.max_retries
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"dashscope_api_key": config.api_key,
|
||||
"max_retries": config.max_retries,
|
||||
}
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
params = {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
}
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
else:
|
||||
params = RedBearModelFactory.get_model_params(config)
|
||||
return embedding_class(**params)
|
||||
|
||||
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||
"""创建火山引擎客户端"""
|
||||
from volcenginesdkarkruntime import Ark
|
||||
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||
|
||||
# ==================== LangChain 标准接口 ====================
|
||||
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self._model.embed_documents(texts)
|
||||
"""批量文本向量化(LangChain 标准接口)"""
|
||||
if self._is_volcano:
|
||||
# 火山引擎多模态 Embedding
|
||||
contents = [{"type": "text", "text": text} for text in texts]
|
||||
response = self._client.multimodal_embeddings.create(
|
||||
model=self._config.model_name,
|
||||
input=contents,
|
||||
encoding_format="float"
|
||||
)
|
||||
return [response.data.embedding]
|
||||
else:
|
||||
# 其他 provider
|
||||
return self._model.embed_documents(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._model.embed_query(text)
|
||||
"""单个文本向量化(LangChain 标准接口)"""
|
||||
if self._is_volcano:
|
||||
# 火山引擎多模态 Embedding
|
||||
result = self.embed_documents([text])
|
||||
return result[0] if result else []
|
||||
else:
|
||||
# 其他 provider
|
||||
return self._model.embed_query(text)
|
||||
|
||||
# ==================== 多模态扩展方法 ====================
|
||||
|
||||
def embed_multimodal(
|
||||
self,
|
||||
contents: List[Dict[str, Any]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
多模态向量化(仅火山引擎支持)
|
||||
|
||||
Args:
|
||||
contents: 内容列表,格式:
|
||||
- 文本: {"type": "text", "text": "..."}
|
||||
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
|
||||
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
response = self._client.multimodal_embeddings.create(
|
||||
model=self._config.model_name,
|
||||
input=contents,
|
||||
**kwargs
|
||||
)
|
||||
return [response.data.embedding]
|
||||
|
||||
async def aembed_multimodal(
|
||||
self,
|
||||
contents: List[Dict[str, Any]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""异步多模态向量化"""
|
||||
# 火山引擎 SDK 暂不支持异步,使用同步方法
|
||||
return self.embed_multimodal(contents, **kwargs)
|
||||
|
||||
def embed_text(self, text: str, **kwargs) -> List[float]:
|
||||
"""文本向量化(便捷方法)"""
|
||||
if self._is_volcano:
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "text", "text": text}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
else:
|
||||
return self.embed_query(text)
|
||||
|
||||
def embed_image(self, image_url: str, **kwargs) -> List[float]:
|
||||
"""图片向量化(仅火山引擎支持)"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "image_url", "image_url": {"url": image_url}}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
|
||||
def embed_video(self, video_url: str, **kwargs) -> List[float]:
|
||||
"""视频向量化(仅火山引擎支持)"""
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
result = self.embed_multimodal(
|
||||
[{"type": "video_url", "video_url": {"url": video_url}}],
|
||||
**kwargs
|
||||
)
|
||||
return result[0] if result else []
|
||||
|
||||
def embed_batch(
|
||||
self,
|
||||
items: List[Union[str, Dict[str, Any]]],
|
||||
**kwargs
|
||||
) -> List[List[float]]:
|
||||
"""
|
||||
批量向量化(支持混合类型)
|
||||
|
||||
Args:
|
||||
items: 可以是字符串列表或内容字典列表
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
# 如果全是字符串,使用标准方法
|
||||
if all(isinstance(item, str) for item in items):
|
||||
return self.embed_documents(items)
|
||||
|
||||
# 如果包含字典,需要多模态支持
|
||||
if not self._is_volcano:
|
||||
raise NotImplementedError(
|
||||
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||
)
|
||||
|
||||
# 标准化输入格式
|
||||
contents = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
contents.append({"type": "text", "text": item})
|
||||
elif isinstance(item, dict):
|
||||
contents.append(item)
|
||||
else:
|
||||
raise ValueError(f"不支持的输入类型: {type(item)}")
|
||||
|
||||
return self.embed_multimodal(contents, **kwargs)
|
||||
|
||||
# ==================== 工具方法 ====================
|
||||
|
||||
def is_multimodal_supported(self) -> bool:
|
||||
"""检查是否支持多模态"""
|
||||
return self._is_volcano
|
||||
|
||||
def get_provider(self) -> str:
|
||||
"""获取 provider"""
|
||||
return self._config.provider
|
||||
|
||||
|
||||
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||||
RedBearMultimodalEmbeddings = RedBearEmbeddings
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user