Compare commits
329 Commits
feature/to
...
v0.2.9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ef626951bc | ||
|
|
4d6038c3cc | ||
|
|
d4450658a8 | ||
|
|
3ceb2efeaf | ||
|
|
e134b96333 | ||
|
|
4df41966fe | ||
|
|
2d6cde157e | ||
|
|
abc27c8372 | ||
|
|
dbe387f666 | ||
|
|
5e70d436a8 | ||
|
|
b7198f1abd | ||
|
|
5c87a2beeb | ||
|
|
3419bb137a | ||
|
|
a00684c67d | ||
|
|
6e7c641fd4 | ||
|
|
0c677701c0 | ||
|
|
4974f9aa98 | ||
|
|
c90b58bbcd | ||
|
|
d6a243f1be | ||
|
|
418114ef72 | ||
|
|
ceed61167f | ||
|
|
83774d7443 | ||
|
|
052c7c19b3 | ||
|
|
d42db0ca33 | ||
|
|
e15af5a2ba | ||
|
|
8b44b2cd61 | ||
|
|
9d91453200 | ||
|
|
ea8db7cd90 | ||
|
|
d60f16df1b | ||
|
|
8dd24533bf | ||
|
|
91b7f2a980 | ||
|
|
f7e89af9d2 | ||
|
|
afbd8c9b4f | ||
|
|
09b3b01d37 | ||
|
|
e3dcbed5f9 | ||
|
|
c7b51e7ad8 | ||
|
|
c0cd2373c0 | ||
|
|
6e757ae9e2 | ||
|
|
64a73c41d6 | ||
|
|
dae7431075 | ||
|
|
643bbbcf5c | ||
|
|
6702e86536 | ||
|
|
ab2bdfa088 | ||
|
|
8285250096 | ||
|
|
e59a215078 | ||
|
|
c89eccf8fe | ||
|
|
5703fc0cb4 | ||
|
|
3aed5c447a | ||
|
|
13352178ad | ||
|
|
8f216db353 | ||
|
|
9f6026492d | ||
|
|
b699b746a5 | ||
|
|
6095170169 | ||
|
|
173697e86a | ||
|
|
5c11da6a2e | ||
|
|
96214c433f | ||
|
|
167c915631 | ||
|
|
f485398768 | ||
|
|
289b1989e5 | ||
|
|
8224848ce1 | ||
|
|
c43d258455 | ||
|
|
c3e5c8b8bb | ||
|
|
930cadcaa8 | ||
|
|
57b6b34567 | ||
|
|
f878846364 | ||
|
|
7dce63dc0b | ||
|
|
03bc8ee7f5 | ||
|
|
4aefb01b0b | ||
|
|
4e9b5736b1 | ||
|
|
46fa99a8b8 | ||
|
|
17ea92357d | ||
|
|
bd70a8b812 | ||
|
|
ad5dc3c138 | ||
|
|
e37b1b01ca | ||
|
|
e659ca9fa2 | ||
|
|
758be0087f | ||
|
|
200c13b59f | ||
|
|
32f6886000 | ||
|
|
7fbf3e8873 | ||
|
|
3026702000 | ||
|
|
8677db114b | ||
|
|
2597a1f532 | ||
|
|
4298cd7d06 | ||
|
|
8197f9db35 | ||
|
|
3da6331515 | ||
|
|
539999131c | ||
|
|
d0ca5c8b27 | ||
|
|
ee6b8ffa62 | ||
|
|
14838dc064 | ||
|
|
e017870f44 | ||
|
|
9730c5ce0f | ||
|
|
f30260939a | ||
|
|
8ba0a74473 | ||
|
|
4f69224cfd | ||
|
|
6f7fee18c9 | ||
|
|
cc58c7333c | ||
|
|
c936277507 | ||
|
|
701df40270 | ||
|
|
b724dbe53a | ||
|
|
ac7c891ded | ||
|
|
3ed6f49bb0 | ||
|
|
a416a6b2bd | ||
|
|
35be03803f | ||
|
|
6427018ffb | ||
|
|
06b823ff96 | ||
|
|
0fdb489227 | ||
|
|
f6394a791e | ||
|
|
4bfd4944d0 | ||
|
|
7faf291ec3 | ||
|
|
3d291e3c23 | ||
|
|
b35bedc730 | ||
|
|
4d39cdf464 | ||
|
|
a874cc70a4 | ||
|
|
2319432182 | ||
|
|
7556468c6e | ||
|
|
91d38c0648 | ||
|
|
df3d58d388 | ||
|
|
80856e3c92 | ||
|
|
8c6f395818 | ||
|
|
2f4f7219e3 | ||
|
|
4c5183eddc | ||
|
|
dfc0ee9424 | ||
|
|
8dbb067b83 | ||
|
|
1df3fc416a | ||
|
|
6223b80cc4 | ||
|
|
68489f1b28 | ||
|
|
477853b04e | ||
|
|
863be50aaf | ||
|
|
d72d57f966 | ||
|
|
5b940e5f1a | ||
|
|
9ae1d2f0d9 | ||
|
|
318f1be107 | ||
|
|
4cab6317de | ||
|
|
81bfc9af36 | ||
|
|
189013f0f8 | ||
|
|
6f5bcd18a4 | ||
|
|
c7ef97c7a6 | ||
|
|
4d4a780ab7 | ||
|
|
9d2f3aa8f9 | ||
|
|
f2c9902a07 | ||
|
|
2525f8795c | ||
|
|
b7a03a844f | ||
|
|
c13c3846d1 | ||
|
|
30b5db1e98 | ||
|
|
f92eb9f45a | ||
|
|
a136d44e27 | ||
|
|
65b2f9e6e1 | ||
|
|
5275a274c3 | ||
|
|
4f09c4fbb3 | ||
|
|
7a3220aff5 | ||
|
|
14a32778f7 | ||
|
|
2a12cb04bf | ||
|
|
1e986c641f | ||
|
|
38c6c7f053 | ||
|
|
7c0743eb8f | ||
|
|
e981f066a3 | ||
|
|
db14d40fb3 | ||
|
|
e8d575fd0b | ||
|
|
a7285e35ad | ||
|
|
c4461c4917 | ||
|
|
2df615eca0 | ||
|
|
504e5ba61e | ||
|
|
0bae290e0c | ||
|
|
294ee49d59 | ||
|
|
26c36f70e6 | ||
|
|
c4b83b1f9c | ||
|
|
14413fd413 | ||
|
|
caab58dd2f | ||
|
|
0e899bea05 | ||
|
|
1794f8f209 | ||
|
|
85daf576e9 | ||
|
|
56fd5680cf | ||
|
|
0380c13a3b | ||
|
|
9ddc523f91 | ||
|
|
491ef27b8a | ||
|
|
edd115582f | ||
|
|
45eef12842 | ||
|
|
49364802c2 | ||
|
|
8873078006 | ||
|
|
2b9fd33bc8 | ||
|
|
e86d679ae5 | ||
|
|
def7367e33 | ||
|
|
54cff5861a | ||
|
|
dc2a73155b | ||
|
|
1856c55c04 | ||
|
|
522eb569f1 | ||
|
|
9df41456f6 | ||
|
|
04c54081c8 | ||
|
|
1c49e3c167 | ||
|
|
fb6ce839d2 | ||
|
|
c7275dccac | ||
|
|
d62b484d71 | ||
|
|
8ff1c6bd08 | ||
|
|
3dcf901043 | ||
|
|
d6dfc2cb12 | ||
|
|
8a3032ce4a | ||
|
|
391c60c812 | ||
|
|
b739b032d9 | ||
|
|
3dc863cabf | ||
|
|
611b14dfea | ||
|
|
de6e2f54d2 | ||
|
|
89d188fbf3 | ||
|
|
6bba574ca6 | ||
|
|
9cbffd6408 | ||
|
|
4d2ad5757c | ||
|
|
cd0ca9cae4 | ||
|
|
3369b702e4 | ||
|
|
cbec2c1356 | ||
|
|
5987eee0a8 | ||
|
|
6348304b7d | ||
|
|
59f8010519 | ||
|
|
9308c6efae | ||
|
|
2f78b7cf5e | ||
|
|
f86448f4bf | ||
|
|
48e2e613bb | ||
|
|
1060074740 | ||
|
|
95b7df7e38 | ||
|
|
fd1634eec4 | ||
|
|
efeead41b2 | ||
|
|
a3428c2435 | ||
|
|
31b8a3764e | ||
|
|
2ff81ba101 | ||
|
|
93deb286a3 | ||
|
|
7bd97bf6d3 | ||
|
|
2d1a1b4a1f | ||
|
|
503c890d93 | ||
|
|
1f73501786 | ||
|
|
eef13cb717 | ||
|
|
c70ac1339e | ||
|
|
24c13d408e | ||
|
|
338d7f1065 | ||
|
|
27672cfaa0 | ||
|
|
4dbb2bf2e2 | ||
|
|
37bc4beab4 | ||
|
|
6056952936 | ||
|
|
31085ed678 | ||
|
|
dce7206c44 | ||
|
|
c17a2dad2d | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
a079358028 | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
9f947a3395 | ||
|
|
bf5c4628c3 | ||
|
|
911d5e0b34 | ||
|
|
bd31aa5abf | ||
|
|
0775fad5f0 | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
b1b53f6b1d | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
69c001bf84 | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
84c23e7c4e | ||
|
|
ba65b06582 | ||
|
|
f4f04036f3 | ||
|
|
477d404727 | ||
|
|
88598fb9fb | ||
|
|
f09de3a11c | ||
|
|
e13acdc8a9 | ||
|
|
3c99fb116c | ||
|
|
b9ebe22df1 | ||
|
|
509d1a2e24 | ||
|
|
153e68e055 | ||
|
|
77b9a6a94e | ||
|
|
d68bbab419 | ||
|
|
6d53d9178c | ||
|
|
06fe3f2f01 | ||
|
|
e2b6c713e7 | ||
|
|
0b3b241436 | ||
|
|
4c18f9e858 | ||
|
|
8fec54c085 | ||
|
|
d8e37a4d2b | ||
|
|
1da2c4fa37 |
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||||
|
|
||||||
|
# Thread-local storage for connection pools.
|
||||||
|
# Each thread (and each forked process) gets its own pool to avoid
|
||||||
|
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||||
|
# and stale connections after fork in --pool=prefork.
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||||
|
"""Return a Redis client whose connection pool is bound to the current
|
||||||
|
thread, process **and** event loop.
|
||||||
|
|
||||||
|
The pool is recreated when:
|
||||||
|
- The PID changes (fork, Celery --pool=prefork)
|
||||||
|
- The thread has no pool yet (Celery --pool=threads)
|
||||||
|
- The previously-cached event loop has been closed (Celery tasks call
|
||||||
|
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||||
|
"""
|
||||||
|
current_pid = os.getpid()
|
||||||
|
cached_loop = getattr(_thread_local, "loop", None)
|
||||||
|
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||||
|
|
||||||
|
if not hasattr(_thread_local, "pool") \
|
||||||
|
or getattr(_thread_local, "pid", None) != current_pid \
|
||||||
|
or loop_stale:
|
||||||
|
_thread_local.pid = current_pid
|
||||||
|
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||||
|
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||||
|
try:
|
||||||
|
_thread_local.loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
_thread_local.loop = None
|
||||||
|
_thread_local.pool = ConnectionPool.from_url(
|
||||||
|
_REDIS_URL,
|
||||||
|
db=settings.REDIS_DB,
|
||||||
|
password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True,
|
||||||
|
max_connections=5,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
|||||||
val = json.dumps(val, ensure_ascii=False)
|
val = json.dumps(val, ensure_ascii=False)
|
||||||
|
|
||||||
if expire is not None:
|
if expire is not None:
|
||||||
# 设置带过期时间的键值
|
|
||||||
await aio_redis.set(key, val, ex=expire)
|
await aio_redis.set(key, val, ex=expire)
|
||||||
else:
|
else:
|
||||||
# 设置永久键值
|
|
||||||
await aio_redis.set(key, val)
|
await aio_redis.set(key, val)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|||||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
|||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中活动统计缓存: {key}")
|
logger.info(f"命中活动统计缓存: {key}")
|
||||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
|||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中兴趣分布缓存: {key}")
|
logger.info(f"命中兴趣分布缓存: {key}")
|
||||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
|
import re
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
@@ -11,21 +12,24 @@ from app.core.logging_config import get_logger
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _mask_url(url: str) -> str:
|
||||||
|
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
|
||||||
|
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
|
||||||
|
|
||||||
# macOS fork() safety - must be set before any Celery initialization
|
# macOS fork() safety - must be set before any Celery initialization
|
||||||
if platform.system() == 'Darwin':
|
if platform.system() == 'Darwin':
|
||||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||||
|
|
||||||
# 创建 Celery 应用实例
|
# 创建 Celery 应用实例
|
||||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
# broker: 优先使用环境变量 CELERY_BROKER_URL(支持 amqp:// 等任意协议),
|
||||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
# 未配置则回退到 Redis 方案
|
||||||
|
# backend: 结果存储(使用 Redis)
|
||||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||||
|
|
||||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
_broker_url = os.getenv("CELERY_BROKER_URL") or \
|
||||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||||
# cannot be overridden by stray env vars.
|
|
||||||
# See: https://github.com/celery/celery/issues/4284
|
|
||||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
|
||||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||||
@@ -45,8 +49,8 @@ celery_app = Celery(
|
|||||||
logger.info(
|
logger.info(
|
||||||
"Celery app initialized",
|
"Celery app initialized",
|
||||||
extra={
|
extra={
|
||||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
"broker": _mask_url(_broker_url),
|
||||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
"backend": _mask_url(_backend_url),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
# Default queue for unrouted tasks
|
# Default queue for unrouted tasks
|
||||||
@@ -77,6 +81,7 @@ celery_app.conf.update(
|
|||||||
|
|
||||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
|
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
|
||||||
|
|
||||||
# 结果过期时间
|
# 结果过期时间
|
||||||
result_expires=3600, # 结果保存1小时
|
result_expires=3600, # 结果保存1小时
|
||||||
@@ -103,6 +108,9 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||||
|
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from fastapi import APIRouter
|
|||||||
from . import (
|
from . import (
|
||||||
api_key_controller,
|
api_key_controller,
|
||||||
app_controller,
|
app_controller,
|
||||||
|
app_log_controller,
|
||||||
auth_controller,
|
auth_controller,
|
||||||
chunk_controller,
|
chunk_controller,
|
||||||
document_controller,
|
document_controller,
|
||||||
@@ -70,6 +71,7 @@ manager_router.include_router(chunk_controller.router)
|
|||||||
manager_router.include_router(test_controller.router)
|
manager_router.include_router(test_controller.router)
|
||||||
manager_router.include_router(knowledgeshare_controller.router)
|
manager_router.include_router(knowledgeshare_controller.router)
|
||||||
manager_router.include_router(app_controller.router)
|
manager_router.include_router(app_controller.router)
|
||||||
|
manager_router.include_router(app_log_controller.router)
|
||||||
manager_router.include_router(upload_controller.router)
|
manager_router.include_router(upload_controller.router)
|
||||||
manager_router.include_router(memory_agent_controller.router)
|
manager_router.include_router(memory_agent_controller.router)
|
||||||
manager_router.include_router(memory_dashboard_controller.router)
|
manager_router.include_router(memory_dashboard_controller.router)
|
||||||
|
|||||||
@@ -65,16 +65,42 @@ def list_apps(
|
|||||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||||
|
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||||
"""
|
"""
|
||||||
|
from sqlalchemy import select as sa_select
|
||||||
|
from app.models.api_key_model import ApiKey
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
service = app_service.AppService(db)
|
service = app_service.AppService(db)
|
||||||
|
|
||||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||||
|
if search:
|
||||||
|
search = search.strip()
|
||||||
|
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||||
|
if len(search) >= 10:
|
||||||
|
matched_id = db.execute(
|
||||||
|
sa_select(ApiKey.resource_id).where(
|
||||||
|
ApiKey.workspace_id == workspace_id,
|
||||||
|
ApiKey.api_key == search,
|
||||||
|
ApiKey.resource_id.isnot(None),
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if matched_id:
|
||||||
|
# 找到 API Key,直接返回关联的应用
|
||||||
|
ids = str(matched_id)
|
||||||
|
|
||||||
|
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
if app_ids:
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
return success(data=items)
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
|
# 返回标准分页格式
|
||||||
|
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
# ids 为空时,返回空列表
|
||||||
|
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=[]))
|
||||||
|
|
||||||
# 正常分页查询
|
# 正常分页查询
|
||||||
items_orm, total = app_service.list_apps(
|
items_orm, total = app_service.list_apps(
|
||||||
|
|||||||
89
api/app/controllers/app_log_controller.py
Normal file
89
api/app/controllers/app_log_controller.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""应用日志(消息记录)接口"""
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
|
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||||
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
|
from app.services.app_service import AppService
|
||||||
|
from app.services.app_log_service import AppLogService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def list_app_logs(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
page: int = Query(1, ge=1),
|
||||||
|
pagesize: int = Query(20, ge=1, le=100),
|
||||||
|
is_draft: Optional[bool] = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""查看应用下所有会话记录(分页)
|
||||||
|
|
||||||
|
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||||
|
- 按最新更新时间倒序排列
|
||||||
|
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 验证应用访问权限
|
||||||
|
app_service = AppService(db)
|
||||||
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
|
# 使用 Service 层查询
|
||||||
|
log_service = AppLogService(db)
|
||||||
|
conversations, total = log_service.list_conversations(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
is_draft=is_draft
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||||
|
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||||
|
|
||||||
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_app_log_detail(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""查看某会话的完整消息记录
|
||||||
|
|
||||||
|
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||||
|
- 消息 meta_data 包含模型名、token 用量等信息
|
||||||
|
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
# 验证应用访问权限
|
||||||
|
app_service = AppService(db)
|
||||||
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
|
# 使用 Service 层查询
|
||||||
|
log_service = AppLogService(db)
|
||||||
|
conversation = log_service.get_conversation_detail(
|
||||||
|
app_id=app_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
detail = AppLogConversationDetail.model_validate(conversation)
|
||||||
|
|
||||||
|
return success(data=detail)
|
||||||
@@ -14,6 +14,9 @@ Routes:
|
|||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
import httpx
|
||||||
|
import mimetypes
|
||||||
|
from urllib.parse import urlparse, unquote
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
@@ -91,7 +94,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +175,6 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
# Get share and release info from share_token
|
# Get share and release info from share_token
|
||||||
service = ReleaseShareService(db)
|
service = ReleaseShareService(db)
|
||||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
# Get share object to access app_id
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
share = service.repo.get_by_share_token(share_data.share_token)
|
||||||
@@ -291,6 +293,101 @@ async def upload_file_with_share_token(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/info-by-url", response_model=ApiResponse)
|
||||||
|
async def get_file_info_by_url(
|
||||||
|
url: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file information by network URL (no authentication required).
|
||||||
|
|
||||||
|
Fetches file metadata from a remote URL via HTTP HEAD request.
|
||||||
|
Falls back to GET request if HEAD is not supported.
|
||||||
|
Returns file type, name, and size.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The network URL of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file information.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File info by URL request: url={url}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
# Try HEAD request first
|
||||||
|
response = await client.head(url, follow_redirects=True)
|
||||||
|
|
||||||
|
# If HEAD fails, try GET request (some servers don't support HEAD)
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
|
||||||
|
response = await client.get(url, follow_redirects=True)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unable to access file: HTTP {response.status_code}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get file size from Content-Length header or actual content
|
||||||
|
file_size = response.headers.get("Content-Length")
|
||||||
|
if file_size:
|
||||||
|
file_size = int(file_size)
|
||||||
|
elif hasattr(response, 'content'):
|
||||||
|
file_size = len(response.content)
|
||||||
|
else:
|
||||||
|
file_size = None
|
||||||
|
|
||||||
|
# Get content type from Content-Type header
|
||||||
|
content_type = response.headers.get("Content-Type", "application/octet-stream")
|
||||||
|
# Remove charset and other parameters from content type
|
||||||
|
content_type = content_type.split(';')[0].strip()
|
||||||
|
|
||||||
|
# Extract filename from Content-Disposition or URL
|
||||||
|
file_name = None
|
||||||
|
content_disposition = response.headers.get("Content-Disposition")
|
||||||
|
if content_disposition and "filename=" in content_disposition:
|
||||||
|
parts = content_disposition.split("filename=")
|
||||||
|
if len(parts) > 1:
|
||||||
|
file_name = parts[1].strip('"').strip("'")
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
|
||||||
|
|
||||||
|
# Extract file extension from filename
|
||||||
|
_, file_ext = os.path.splitext(file_name)
|
||||||
|
|
||||||
|
# If no extension found, infer from content type
|
||||||
|
if not file_ext:
|
||||||
|
ext = mimetypes.guess_extension(content_type)
|
||||||
|
if ext:
|
||||||
|
file_ext = ext
|
||||||
|
file_name = f"{file_name}{file_ext}"
|
||||||
|
|
||||||
|
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"url": url,
|
||||||
|
"file_name": file_name,
|
||||||
|
"file_ext": file_ext.lower() if file_ext else "",
|
||||||
|
"file_size": file_size,
|
||||||
|
"content_type": content_type,
|
||||||
|
},
|
||||||
|
msg="File information retrieved successfully"
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Unexpected error: {e}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to retrieve file information: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -477,8 +574,12 @@ async def get_file_url(
|
|||||||
# For local storage, generate signed URL with expiration
|
# For local storage, generate signed URL with expiration
|
||||||
url = generate_signed_url(str(file_id), expires)
|
url = generate_signed_url(str(file_id), expires)
|
||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL with forced download
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(
|
||||||
|
file_key,
|
||||||
|
expires=expires,
|
||||||
|
file_name=file_metadata.file_name,
|
||||||
|
)
|
||||||
url = _match_scheme(request, url)
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
@@ -499,6 +600,51 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
|
||||||
|
async def get_permanent_file_url(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取文件的永久公开 URL(无过期时间)。
|
||||||
|
|
||||||
|
- 本地存储:返回 API 永久访问地址(基于 FILE_LOCAL_SERVER_URL 配置)
|
||||||
|
- 远程存储(OSS/S3):返回 bucket 公读地址(需 bucket 已配置公共读权限)
|
||||||
|
"""
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
|
||||||
|
|
||||||
|
if file_metadata.status != "completed":
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"File upload not completed, status: {file_metadata.status}")
|
||||||
|
|
||||||
|
file_key = file_metadata.file_key
|
||||||
|
storage = storage_service.storage
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(storage, LocalStorage):
|
||||||
|
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
|
||||||
|
else:
|
||||||
|
url = await storage.get_permanent_url(file_key)
|
||||||
|
if not url:
|
||||||
|
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="Permanent URL not supported for current storage backend")
|
||||||
|
|
||||||
|
api_logger.info(f"Generated permanent URL: file_id={file_id}")
|
||||||
|
return success(
|
||||||
|
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
|
||||||
|
msg="Permanent file URL generated successfully"
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"Failed to generate permanent URL: {e}")
|
||||||
|
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to generate permanent URL: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
request: Request,
|
request: Request,
|
||||||
@@ -644,7 +790,7 @@ async def permanent_download_file(
|
|||||||
# For remote storage, redirect to presigned URL with long expiration
|
# For remote storage, redirect to presigned URL with long expiration
|
||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name)
|
||||||
presigned_url = _match_scheme(request, presigned_url)
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -653,3 +799,44 @@ async def permanent_download_file(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to retrieve file: {str(e)}"
|
detail=f"Failed to retrieve file: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/status", response_model=ApiResponse)
|
||||||
|
async def get_file_status(
|
||||||
|
file_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get file upload/processing status (no authentication required).
|
||||||
|
|
||||||
|
This endpoint is used to check if a file (e.g., TTS audio) is ready.
|
||||||
|
Returns status: pending, completed, or failed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: The UUID of the file.
|
||||||
|
db: Database session.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse with file status and metadata.
|
||||||
|
"""
|
||||||
|
api_logger.info(f"File status request: file_id={file_id}")
|
||||||
|
|
||||||
|
# Query file metadata from database
|
||||||
|
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||||
|
if not file_metadata:
|
||||||
|
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The file does not exist"
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"file_id": str(file_id),
|
||||||
|
"status": file_metadata.status,
|
||||||
|
"file_name": file_metadata.file_name,
|
||||||
|
"file_size": file_metadata.file_size,
|
||||||
|
"content_type": file_metadata.content_type,
|
||||||
|
},
|
||||||
|
msg="File status retrieved successfully"
|
||||||
|
)
|
||||||
|
|||||||
@@ -91,9 +91,11 @@ async def get_mcp_servers(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(token)
|
cookies = api.get_cookies(token)
|
||||||
|
headers=api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
r = api.session.put(
|
r = api.session.put(
|
||||||
url=api.mcp_base_url,
|
url=api.mcp_base_url,
|
||||||
headers=api.builder_headers(api.headers),
|
headers=headers,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies)
|
cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
@@ -173,6 +175,7 @@ async def get_operational_mcp_servers(
|
|||||||
|
|
||||||
url = f'{api.mcp_base_url}/operational'
|
url = f'{api.mcp_base_url}/operational'
|
||||||
headers = api.builder_headers(api.headers)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||||
@@ -260,7 +263,9 @@ async def create_mcp_market_config(
|
|||||||
api.login(create_data.token)
|
api.login(create_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(create_data.token)
|
cookies = api.get_cookies(create_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {create_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
@@ -290,9 +295,11 @@ async def create_mcp_market_config(
|
|||||||
'search': ""
|
'search': ""
|
||||||
}
|
}
|
||||||
cookies = api.get_cookies(token)
|
cookies = api.get_cookies(token)
|
||||||
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
r = api.session.put(
|
r = api.session.put(
|
||||||
url=api.mcp_base_url,
|
url=api.mcp_base_url,
|
||||||
headers=api.builder_headers(api.headers),
|
headers=headers,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies)
|
cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
@@ -393,7 +400,9 @@ async def update_mcp_market_config(
|
|||||||
api.login(update_data.token)
|
api.login(update_data.token)
|
||||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
cookies = api.get_cookies(update_data.token)
|
cookies = api.get_cookies(update_data.token)
|
||||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
headers = api.builder_headers(api.headers)
|
||||||
|
headers['Authorization'] = f'Bearer {update_data.token}'
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
|
|||||||
@@ -118,142 +118,142 @@ async def download_log(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/writer_service", response_model=ApiResponse)
|
# @router.post("/writer_service", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server(
|
# async def write_server(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Write service endpoint - processes write operations synchronously
|
# Write service endpoint - processes write operations synchronously
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Response with write operation status
|
# Response with write operation status
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
|
#
|
||||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
# # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||||
if storage_type == 'rag':
|
# if storage_type == 'rag':
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge:
|
# if knowledge:
|
||||||
user_rag_memory_id = str(knowledge.id)
|
# user_rag_memory_id = str(knowledge.id)
|
||||||
else:
|
# else:
|
||||||
api_logger.warning(
|
# api_logger.warning(
|
||||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
# f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
else:
|
# else:
|
||||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||||
storage_type = 'neo4j'
|
# storage_type = 'neo4j'
|
||||||
|
#
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
result = await memory_agent_service.write_memory(
|
# result = await memory_agent_service.write_memory(
|
||||||
user_input.end_user_id,
|
# user_input.end_user_id,
|
||||||
messages_list,
|
# messages_list,
|
||||||
config_id,
|
# config_id,
|
||||||
db,
|
# db,
|
||||||
storage_type,
|
# storage_type,
|
||||||
user_rag_memory_id,
|
# user_rag_memory_id,
|
||||||
language
|
# language
|
||||||
)
|
# )
|
||||||
|
#
|
||||||
return success(data=result, msg="写入成功")
|
# return success(data=result, msg="写入成功")
|
||||||
except BaseException as e:
|
# except BaseException as e:
|
||||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||||
if hasattr(e, 'exceptions'):
|
# if hasattr(e, 'exceptions'):
|
||||||
error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
|
||||||
detailed_error = "; ".join(error_messages)
|
# detailed_error = "; ".join(error_messages)
|
||||||
api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
|
||||||
api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
#
|
||||||
|
#
|
||||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
# @router.post("/writer_service_async", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
# @cur_workspace_access_guard()
|
||||||
async def write_server_async(
|
# async def write_server_async(
|
||||||
user_input: Write_UserInput,
|
# user_input: Write_UserInput,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
# language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
# db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user)
|
# current_user: User = Depends(get_current_user)
|
||||||
):
|
# ):
|
||||||
"""
|
# """
|
||||||
Async write service endpoint - enqueues write processing to Celery
|
# Async write service endpoint - enqueues write processing to Celery
|
||||||
|
#
|
||||||
Args:
|
# Args:
|
||||||
user_input: Write request containing message and end_user_id
|
# user_input: Write request containing message and end_user_id
|
||||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||||
|
#
|
||||||
Returns:
|
# Returns:
|
||||||
Task ID for tracking async operation
|
# Task ID for tracking async operation
|
||||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
# Use GET /memory/write_result/{task_id} to check task status and get result
|
||||||
"""
|
# """
|
||||||
# 使用集中化的语言校验
|
# # 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
# language = get_language_from_header(language_type)
|
||||||
|
#
|
||||||
config_id = user_input.config_id
|
# config_id = user_input.config_id
|
||||||
workspace_id = current_user.current_workspace_id
|
# workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(
|
# api_logger.info(
|
||||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||||
|
#
|
||||||
# 获取 storage_type,如果为 None 则使用默认值
|
# # 获取 storage_type,如果为 None 则使用默认值
|
||||||
storage_type = workspace_service.get_workspace_storage_type(
|
# storage_type = workspace_service.get_workspace_storage_type(
|
||||||
db=db,
|
# db=db,
|
||||||
workspace_id=workspace_id,
|
# workspace_id=workspace_id,
|
||||||
user=current_user
|
# user=current_user
|
||||||
)
|
# )
|
||||||
if storage_type is None: storage_type = 'neo4j'
|
# if storage_type is None: storage_type = 'neo4j'
|
||||||
user_rag_memory_id = ''
|
# user_rag_memory_id = ''
|
||||||
if workspace_id:
|
# if workspace_id:
|
||||||
|
#
|
||||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
# knowledge = knowledge_repository.get_knowledge_by_name(
|
||||||
db=db,
|
# db=db,
|
||||||
name="USER_RAG_MERORY",
|
# name="USER_RAG_MERORY",
|
||||||
workspace_id=workspace_id
|
# workspace_id=workspace_id
|
||||||
)
|
# )
|
||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
# if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
# try:
|
||||||
# 获取标准化的消息列表
|
# # 获取标准化的消息列表
|
||||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
# messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
#
|
||||||
task = celery_app.send_task(
|
# task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
# "app.core.memory.agent.write_message",
|
||||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||||
)
|
# )
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
# api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
#
|
||||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
# return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
# api_logger.error(f"Async write operation failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/read_service", response_model=ApiResponse)
|
@router.post("/read_service", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -195,10 +195,9 @@ async def get_workspace_end_users(
|
|||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
|
||||||
try:
|
try:
|
||||||
from app.tasks import init_community_clustering_for_users
|
from app.tasks import init_community_clustering_for_users
|
||||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
@@ -664,9 +663,12 @@ async def dashboard_data(
|
|||||||
rag_data["total_memory"] = total_chunk
|
rag_data["total_memory"] = total_chunk
|
||||||
|
|
||||||
# total_app: 统计当前空间下的所有app数量
|
# total_app: 统计当前空间下的所有app数量
|
||||||
from app.repositories import app_repository
|
# 包含自有app + 被分享给本工作空间的app
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
from app.services import app_service as _app_svc
|
||||||
rag_data["total_app"] = len(apps_orm)
|
_, total_app = _app_svc.AppService(db).list_apps(
|
||||||
|
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||||
|
)
|
||||||
|
rag_data["total_app"] = total_app
|
||||||
|
|
||||||
# total_knowledge: 使用 total_kb(总知识库数)
|
# total_knowledge: 使用 total_kb(总知识库数)
|
||||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||||
@@ -688,7 +690,7 @@ async def dashboard_data(
|
|||||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||||
rag_data["total_api_call"] = 0
|
rag_data["total_api_call"] = 0
|
||||||
|
|
||||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ForgettingCurveRequest,
|
ForgettingCurveRequest,
|
||||||
ForgettingCurveResponse,
|
ForgettingCurveResponse,
|
||||||
ForgettingCurvePoint,
|
ForgettingCurvePoint,
|
||||||
|
PendingNodesResponse,
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||||
|
async def get_pending_nodes(
|
||||||
|
end_user_id: str,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
此接口独立分页,与 /stats 接口分离。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 组ID(即 end_user_id,必填)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
current_user: 当前用户
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||||
|
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- page 从1开始,pagesize 必须大于0
|
||||||
|
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 验证 end_user_id 必填
|
||||||
|
if not end_user_id:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||||
|
|
||||||
|
# 通过 end_user_id 获取关联的 config_id
|
||||||
|
try:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
|
if config_id is None:
|
||||||
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
|
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
|
except ValueError as e:
|
||||||
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||||
|
|
||||||
|
# 验证分页参数
|
||||||
|
if page < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||||
|
if pagesize < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||||
|
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用服务层获取待遗忘节点列表
|
||||||
|
result = await forget_service.get_pending_nodes(
|
||||||
|
db=db,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
config_id=config_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response_data = PendingNodesResponse(**result)
|
||||||
|
|
||||||
|
return success(data=response_data.model_dump(), msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
request: ForgettingCurveRequest,
|
request: ForgettingCurveRequest,
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/info", response_model=ApiResponse)
|
@router.get("/info", response_model=ApiResponse)
|
||||||
async def get_storage_info(
|
async def get_storage_info(
|
||||||
storage_id: str,
|
storage_id: str,
|
||||||
current_user: User = Depends(get_current_user)
|
current_user: User = Depends(get_current_user)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Example wrapper endpoint - retrieves storage information
|
Example wrapper endpoint - retrieves storage information
|
||||||
@@ -75,17 +75,12 @@ async def get_storage_info(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
|
||||||
def create_config(
|
def create_config(
|
||||||
payload: ConfigParamsCreate,
|
payload: ConfigParamsCreate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
@@ -107,9 +102,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {err_str}")
|
api_logger.error(f"Create config failed: {err_str}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||||
@@ -119,9 +116,11 @@ def create_config(
|
|||||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||||
lang = get_language_from_header(x_language_type)
|
lang = get_language_from_header(x_language_type)
|
||||||
if lang == "en":
|
if lang == "en":
|
||||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
msg = fail(BizCode.BAD_REQUEST, "Config name already exists",
|
||||||
|
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||||
else:
|
else:
|
||||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在",
|
||||||
|
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||||
return JSONResponse(status_code=400, content=msg)
|
return JSONResponse(status_code=400, content=msg)
|
||||||
api_logger.error(f"Create config failed: {str(e)}")
|
api_logger.error(f"Create config failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||||
@@ -129,10 +128,10 @@ def create_config(
|
|||||||
|
|
||||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||||
def delete_config(
|
def delete_config(
|
||||||
config_id: UUID|int,
|
config_id: UUID | int,
|
||||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""删除记忆配置(带终端用户保护)
|
"""删除记忆配置(带终端用户保护)
|
||||||
|
|
||||||
@@ -145,7 +144,7 @@ def delete_config(
|
|||||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id=resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 检查用户是否已选择工作空间
|
# 检查用户是否已选择工作空间
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||||
@@ -203,9 +202,9 @@ def delete_config(
|
|||||||
|
|
||||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||||
def update_config(
|
def update_config(
|
||||||
payload: ConfigUpdate,
|
payload: ConfigUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -217,7 +216,8 @@ def update_config(
|
|||||||
# 校验至少有一个字段需要更新
|
# 校验至少有一个字段需要更新
|
||||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段",
|
||||||
|
"config_name, config_desc, scene_id 均为空")
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||||
try:
|
try:
|
||||||
@@ -231,9 +231,9 @@ def update_config(
|
|||||||
|
|
||||||
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
|
||||||
def update_config_extracted(
|
def update_config_extracted(
|
||||||
payload: ConfigUpdateExtracted,
|
payload: ConfigUpdateExtracted,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||||
@@ -256,11 +256,11 @@ def update_config_extracted(
|
|||||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||||
|
|
||||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||||
def read_config_extracted(
|
def read_config_extracted(
|
||||||
config_id: UUID | int,
|
config_id: UUID | int,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -278,10 +278,11 @@ def read_config_extracted(
|
|||||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||||
|
|
||||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
|
||||||
|
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||||
def read_all_config(
|
def read_all_config(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -303,10 +304,10 @@ def read_all_config(
|
|||||||
|
|
||||||
@router.post("/pilot_run", response_model=None)
|
@router.post("/pilot_run", response_model=None)
|
||||||
async def pilot_run(
|
async def pilot_run(
|
||||||
payload: ConfigPilotRun,
|
payload: ConfigPilotRun,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> StreamingResponse:
|
) -> StreamingResponse:
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
@@ -333,9 +334,9 @@ async def pilot_run(
|
|||||||
|
|
||||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||||
async def get_kb_type_distribution(
|
async def get_kb_type_distribution(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await kb_type_distribution(end_user_id)
|
result = await kb_type_distribution(end_user_id)
|
||||||
@@ -347,9 +348,9 @@ async def get_kb_type_distribution(
|
|||||||
|
|
||||||
@router.get("/search/dialogue", response_model=ApiResponse)
|
@router.get("/search/dialogue", response_model=ApiResponse)
|
||||||
async def search_dialogues_num(
|
async def search_dialogues_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_dialogue(end_user_id)
|
result = await search_dialogue(end_user_id)
|
||||||
@@ -361,9 +362,9 @@ async def search_dialogues_num(
|
|||||||
|
|
||||||
@router.get("/search/chunk", response_model=ApiResponse)
|
@router.get("/search/chunk", response_model=ApiResponse)
|
||||||
async def search_chunks_num(
|
async def search_chunks_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_chunk(end_user_id)
|
result = await search_chunk(end_user_id)
|
||||||
@@ -375,9 +376,9 @@ async def search_chunks_num(
|
|||||||
|
|
||||||
@router.get("/search/statement", response_model=ApiResponse)
|
@router.get("/search/statement", response_model=ApiResponse)
|
||||||
async def search_statements_num(
|
async def search_statements_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_statement(end_user_id)
|
result = await search_statement(end_user_id)
|
||||||
@@ -389,9 +390,9 @@ async def search_statements_num(
|
|||||||
|
|
||||||
@router.get("/search/entity", response_model=ApiResponse)
|
@router.get("/search/entity", response_model=ApiResponse)
|
||||||
async def search_entities_num(
|
async def search_entities_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_entity(end_user_id)
|
result = await search_entity(end_user_id)
|
||||||
@@ -403,9 +404,9 @@ async def search_entities_num(
|
|||||||
|
|
||||||
@router.get("/search", response_model=ApiResponse)
|
@router.get("/search", response_model=ApiResponse)
|
||||||
async def search_all_num(
|
async def search_all_num(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_all(end_user_id)
|
result = await search_all(end_user_id)
|
||||||
@@ -417,9 +418,9 @@ async def search_all_num(
|
|||||||
|
|
||||||
@router.get("/search/detials", response_model=ApiResponse)
|
@router.get("/search/detials", response_model=ApiResponse)
|
||||||
async def search_entities_detials(
|
async def search_entities_detials(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_detials(end_user_id)
|
result = await search_detials(end_user_id)
|
||||||
@@ -431,9 +432,9 @@ async def search_entities_detials(
|
|||||||
|
|
||||||
@router.get("/search/edges", response_model=ApiResponse)
|
@router.get("/search/edges", response_model=ApiResponse)
|
||||||
async def search_entity_edges(
|
async def search_entity_edges(
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
|
||||||
try:
|
try:
|
||||||
result = await search_edges(end_user_id)
|
result = await search_edges(end_user_id)
|
||||||
@@ -443,14 +444,12 @@ async def search_entity_edges(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||||
async def get_hot_memory_tags_api(
|
async def get_hot_memory_tags_api(
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取热门记忆标签(带Redis缓存)
|
获取热门记忆标签(带Redis缓存)
|
||||||
|
|
||||||
@@ -505,8 +504,8 @@ async def get_hot_memory_tags_api(
|
|||||||
|
|
||||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||||
async def clear_hot_memory_tags_cache(
|
async def clear_hot_memory_tags_cache(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
清除热门标签缓存
|
清除热门标签缓存
|
||||||
|
|
||||||
@@ -543,7 +542,7 @@ async def clear_hot_memory_tags_cache(
|
|||||||
|
|
||||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||||
async def get_recent_activity_stats_api(
|
async def get_recent_activity_stats_api(
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||||
@@ -553,4 +552,3 @@ async def get_recent_activity_stats_api(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -33,35 +33,47 @@ def get_memory_count(
|
|||||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
end_user_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve all conversations for the current user in a specific group.
|
Retrieve conversations for the current user in a specific group with pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
pagesize (int): Number of items per page. Defaults to 20.
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Contains a list of conversation IDs.
|
ApiResponse: Contains a paginated list of conversations.
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Initializes the ConversationService with the current DB session.
|
|
||||||
- Returns only conversation IDs for lightweight response.
|
|
||||||
- Logs can be added to trace requests in production.
|
|
||||||
"""
|
"""
|
||||||
|
page = max(1, page)
|
||||||
|
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations = conversation_service.get_user_conversations(
|
conversations, total = conversation_service.get_user_conversations(
|
||||||
end_user_id
|
end_user_id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return success(data=[
|
return success(data={
|
||||||
{
|
"items": [
|
||||||
"id": conversation.id,
|
{
|
||||||
"title": conversation.title
|
"id": conversation.id,
|
||||||
} for conversation in conversations
|
"title": conversation.title
|
||||||
], msg="get conversations success")
|
} for conversation in conversations
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": page_size,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * page_size) < total
|
||||||
|
},
|
||||||
|
}, msg="get conversations success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ def get_model_strategies():
|
|||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_model_list(
|
def get_model_list(
|
||||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||||
|
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding)"),
|
||||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||||
@@ -74,10 +75,21 @@ def get_model_list(
|
|||||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||||
|
|
||||||
|
capability_list = []
|
||||||
|
if capability is not None:
|
||||||
|
flat_capability = []
|
||||||
|
for item in capability:
|
||||||
|
split_items = [c.strip() for c in item.split(', ') if c.strip()]
|
||||||
|
flat_capability.extend(split_items)
|
||||||
|
|
||||||
|
unique_flat_capability = list(dict.fromkeys(flat_capability))
|
||||||
|
capability_list = unique_flat_capability
|
||||||
|
|
||||||
api_logger.error(f"获取模型type_list: {type_list}")
|
api_logger.error(f"获取模型type_list: {type_list}")
|
||||||
query = model_schema.ModelConfigQuery(
|
query = model_schema.ModelConfigQuery(
|
||||||
type=type_list,
|
type=type_list,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
|
capability=capability_list,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
is_public=is_public,
|
is_public=is_public,
|
||||||
search=search,
|
search=search,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
|
|||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
@@ -259,8 +260,41 @@ def get_conversation(
|
|||||||
conv_service = ConversationService(db)
|
conv_service = ConversationService(db)
|
||||||
messages = conv_service.get_messages(conversation_id)
|
messages = conv_service.get_messages(conversation_id)
|
||||||
|
|
||||||
# 构建响应
|
file_ids = []
|
||||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
message_file_id_map = {}
|
||||||
|
|
||||||
|
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||||
|
for idx, m in enumerate(messages):
|
||||||
|
if m.role == "assistant" and m.meta_data:
|
||||||
|
audio_url = m.meta_data.get("audio_url")
|
||||||
|
if not audio_url:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# audio_url 无法解析为 UUID,标记为 unknown
|
||||||
|
m.meta_data["audio_status"] = "unknown"
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_ids.append(file_id)
|
||||||
|
message_file_id_map[idx] = file_id
|
||||||
|
|
||||||
|
# 批量查询所有相关的 FileMetadata
|
||||||
|
file_status_map = {}
|
||||||
|
if file_ids:
|
||||||
|
file_metas = (
|
||||||
|
db.query(FileMetadata)
|
||||||
|
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||||
|
|
||||||
|
# 第二次遍历:将查询结果映射回消息
|
||||||
|
for idx, file_id in message_file_id_map.items():
|
||||||
|
m = messages[idx]
|
||||||
|
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||||
|
|
||||||
|
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||||
conv_dict["messages"] = [
|
conv_dict["messages"] = [
|
||||||
conversation_schema.Message.model_validate(m) for m in messages
|
conversation_schema.Message.model_validate(m) for m in messages
|
||||||
]
|
]
|
||||||
@@ -320,6 +354,16 @@ async def chat(
|
|||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id
|
original_user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||||
|
if not new_end_user.memory_config_id:
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
memory_config_service = MemoryConfigService(db)
|
||||||
|
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||||
|
if memory_config_id:
|
||||||
|
new_end_user.memory_config_id = memory_config_id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(new_end_user)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# appid = share.app_id
|
# appid = share.app_id
|
||||||
@@ -669,6 +713,7 @@ async def config_query(
|
|||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": release.config.get("variables"),
|
"variables": release.config.get("variables"),
|
||||||
|
"memory": release.config.get("memory", {}).get("enabled"),
|
||||||
"features": release.config.get("features")
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
elif release.app.type == AppType.MULTI_AGENT:
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ async def chat(
|
|||||||
|
|
||||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||||
other_id = payload.user_id
|
other_id = payload.user_id
|
||||||
workspace_id = app.workspace_id
|
workspace_id = api_key_auth.workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
|||||||
@@ -76,6 +76,8 @@ async def get_tool_methods(
|
|||||||
if methods is None:
|
if methods is None:
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
return success(data=methods, msg="获取工具方法成功")
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -121,6 +123,8 @@ async def create_tool(
|
|||||||
raise HTTPException(status_code=400, detail=e.message)
|
raise HTTPException(status_code=400, detail=e.message)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -149,6 +153,8 @@ async def update_tool(
|
|||||||
return success(msg="工具更新成功")
|
return success(msg="工具更新成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -191,6 +197,8 @@ async def set_tool_active(
|
|||||||
return success(msg=f"工具已{action}")
|
return success(msg=f"工具已{action}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -223,6 +231,8 @@ async def execute_tool(
|
|||||||
},
|
},
|
||||||
msg="工具执行完成"
|
msg="工具执行完成"
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -111,6 +111,18 @@ def get_current_user_info(
|
|||||||
break
|
break
|
||||||
|
|
||||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||||
|
|
||||||
|
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||||
|
if current_user.external_source:
|
||||||
|
from premium.sso.models import SSOSource
|
||||||
|
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||||
|
if source and source.permissions:
|
||||||
|
result_schema.permissions = source.permissions
|
||||||
|
else:
|
||||||
|
result_schema.permissions = []
|
||||||
|
else:
|
||||||
|
result_schema.permissions = ["all"]
|
||||||
|
|
||||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@@ -135,7 +147,6 @@ def get_tenant_superusers(
|
|||||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=ApiResponse)
|
@router.get("/{user_id}", response_model=ApiResponse)
|
||||||
def get_user_info_by_id(
|
def get_user_info_by_id(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
import datetime
|
import datetime
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from fastapi import APIRouter, Depends,Header
|
from fastapi import APIRouter, Depends, Header
|
||||||
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -19,13 +19,15 @@ from app.services.user_memory_service import (
|
|||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
analytics_community_graph_data,
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||||
from app.repositories.workspace_repository import WorkspaceRepository
|
from app.repositories.workspace_repository import WorkspaceRepository
|
||||||
from app.schemas.end_user_schema import (
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
EndUserProfileResponse,
|
from app.schemas.end_user_info_schema import (
|
||||||
EndUserProfileUpdate,
|
EndUserInfoResponse,
|
||||||
|
EndUserInfoCreate,
|
||||||
|
EndUserInfoUpdate,
|
||||||
)
|
)
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
@@ -45,9 +47,9 @@ router = APIRouter(
|
|||||||
|
|
||||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||||
async def get_memory_insight_report_api(
|
async def get_memory_insight_report_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的记忆洞察报告
|
获取缓存的记忆洞察报告
|
||||||
@@ -73,10 +75,10 @@ async def get_memory_insight_report_api(
|
|||||||
|
|
||||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||||
async def get_user_summary_api(
|
async def get_user_summary_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
获取缓存的用户摘要
|
获取缓存的用户摘要
|
||||||
@@ -102,7 +104,7 @@ async def get_user_summary_api(
|
|||||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取缓存数据
|
# 调用服务层获取缓存数据
|
||||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language)
|
||||||
|
|
||||||
if result["is_cached"]:
|
if result["is_cached"]:
|
||||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||||
@@ -117,10 +119,10 @@ async def get_user_summary_api(
|
|||||||
|
|
||||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||||
async def generate_cache_api(
|
async def generate_cache_api(
|
||||||
request: GenerateCacheRequest,
|
request: GenerateCacheRequest,
|
||||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
手动触发缓存生成
|
手动触发缓存生成
|
||||||
@@ -155,10 +157,12 @@ async def generate_cache_api(
|
|||||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||||
|
|
||||||
# 生成记忆洞察
|
# 生成记忆洞察
|
||||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 生成用户摘要
|
# 生成用户摘要
|
||||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id,
|
||||||
|
language=language)
|
||||||
|
|
||||||
# 构建响应
|
# 构建响应
|
||||||
result = {
|
result = {
|
||||||
@@ -209,9 +213,9 @@ async def generate_cache_api(
|
|||||||
|
|
||||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||||
async def get_node_statistics_api(
|
async def get_node_statistics_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -220,7 +224,8 @@ async def get_node_statistics_api(
|
|||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
api_logger.info(
|
||||||
|
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用新的记忆类型统计函数
|
# 调用新的记忆类型统计函数
|
||||||
@@ -228,21 +233,23 @@ async def get_node_statistics_api(
|
|||||||
|
|
||||||
# 计算总数用于日志
|
# 计算总数用于日志
|
||||||
total_count = sum(item["count"] for item in result)
|
total_count = sum(item["count"] for item in result)
|
||||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
api_logger.info(
|
||||||
|
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||||
return success(data=result, msg="查询成功")
|
return success(data=result, msg="查询成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||||
async def get_graph_data_api(
|
async def get_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
node_types: Optional[str] = None,
|
node_types: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
center_node_id: Optional[str] = None,
|
center_node_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -298,9 +305,9 @@ async def get_graph_data_api(
|
|||||||
|
|
||||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||||
async def get_community_graph_data_api(
|
async def get_community_graph_data_api(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
@@ -331,111 +338,130 @@ async def get_community_graph_data_api(
|
|||||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||||
|
|
||||||
|
#=======================终端用户信息接口=======================
|
||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/end_user_info", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_end_user_info(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
workspace_id = current_user.current_workspace_id
|
"""
|
||||||
workspace_repo = WorkspaceRepository(db)
|
查询终端用户信息记录
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
|
||||||
|
根据 end_user_id 查询单条终端用户信息记录。
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
if workspace_models:
|
|
||||||
model_id = workspace_models.get("llm", None)
|
|
||||||
else:
|
|
||||||
model_id = None
|
|
||||||
# 检查用户是否已选择工作空间
|
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
# 校验 end_user 是否属于当前工作空间
|
||||||
# 查询终端用户
|
end_user_repo = EndUserRepository(db)
|
||||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||||
|
if end_user is None:
|
||||||
if not end_user:
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
if str(end_user.workspace_id) != str(workspace_id):
|
||||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
api_logger.warning(
|
||||||
# 构建响应数据
|
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||||
profile_data = EndUserProfileResponse(
|
|
||||||
id=end_user.id,
|
|
||||||
other_name=end_user.other_name,
|
|
||||||
position=end_user.position,
|
|
||||||
department=end_user.department,
|
|
||||||
contact=end_user.contact,
|
|
||||||
phone=end_user.phone,
|
|
||||||
hire_date=end_user.hire_date,
|
|
||||||
updatetime_profile=end_user.updatetime_profile
|
|
||||||
)
|
)
|
||||||
|
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||||
|
|
||||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
result = user_memory_service.get_end_user_info(db, end_user_id)
|
||||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
|
||||||
|
|
||||||
except Exception as e:
|
if result["success"]:
|
||||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
return success(data=result["data"], msg="查询成功")
|
||||||
|
else:
|
||||||
|
error_msg = result["error"]
|
||||||
|
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
|
if error_msg == "终端用户信息记录不存在":
|
||||||
|
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||||
|
elif error_msg == "无效的终端用户ID格式":
|
||||||
|
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||||
|
else:
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
@router.post("/end_user_info/updated", response_model=ApiResponse)
|
||||||
async def update_end_user_profile(
|
async def update_end_user_info(
|
||||||
profile_update: EndUserProfileUpdate,
|
info_update: EndUserInfoUpdate,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
更新终端用户的基本信息
|
更新终端用户信息记录
|
||||||
|
|
||||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名。
|
||||||
所有字段都是可选的,只更新提供的字段。
|
|
||||||
|
示例请求体:
|
||||||
|
{
|
||||||
|
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||||
|
"other_name": "张三1",
|
||||||
|
"aliases": ["小张", "张工"],
|
||||||
|
"meta_data": {"position": "工程师", "department": "技术部"}
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
end_user_id = profile_update.end_user_id
|
end_user_id = info_update.end_user_id
|
||||||
|
|
||||||
# 验证工作空间
|
|
||||||
if workspace_id is None:
|
if workspace_id is None:
|
||||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间")
|
||||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
f"workspace={workspace_id}"
|
f"workspace={workspace_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用 Service 层处理业务逻辑
|
# 校验 end_user 是否属于当前工作空间
|
||||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_end_user_by_id(end_user_id)
|
||||||
|
if end_user is None:
|
||||||
|
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
|
||||||
|
if str(end_user.workspace_id) != str(workspace_id):
|
||||||
|
api_logger.warning(
|
||||||
|
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
|
||||||
|
)
|
||||||
|
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
|
||||||
|
|
||||||
|
# 获取更新数据(排除 end_user_id)
|
||||||
|
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||||
|
|
||||||
|
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
|
||||||
|
|
||||||
if result["success"]:
|
if result["success"]:
|
||||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}")
|
||||||
return success(data=result["data"], msg="更新成功")
|
return success(data=result["data"], msg="更新成功")
|
||||||
else:
|
else:
|
||||||
error_msg = result["error"]
|
error_msg = result["error"]
|
||||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||||
|
|
||||||
# 根据错误类型映射到合适的业务错误码
|
if error_msg == "终端用户信息记录不存在":
|
||||||
if error_msg == "终端用户不存在":
|
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
|
||||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
elif error_msg == "无效的终端用户ID格式":
|
||||||
elif error_msg == "无效的用户ID格式":
|
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
|
||||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
|
||||||
else:
|
else:
|
||||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg)
|
||||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
|
||||||
|
|
||||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
async def memory_space_timeline_of_shared_memories(
|
||||||
current_user: User = Depends(get_current_user),
|
id: str, label: str,
|
||||||
db: Session = Depends(get_db),
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
):
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
# 使用集中化的语言校验
|
# 使用集中化的语言校验
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
workspace_id=current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
workspace_repo = WorkspaceRepository(db)
|
workspace_repo = WorkspaceRepository(db)
|
||||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||||
|
|
||||||
@@ -447,11 +473,13 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
|||||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||||
|
|
||||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||||
async def memory_space_relationship_evolution(id: str, label: str,
|
async def memory_space_relationship_evolution(id: str, label: str,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
|
||||||
|
|
||||||
|
|||||||
@@ -329,7 +329,6 @@ class LangChainAgent:
|
|||||||
db.close()
|
db.close()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
logger.warning(f"Failed to get db session: {e}")
|
||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
|
||||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||||
try:
|
try:
|
||||||
@@ -598,8 +597,10 @@ class LangChainAgent:
|
|||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
total_tokens = response_meta.get("token_usage", {}).get(
|
||||||
0) if response_meta else 0
|
"total_tokens",
|
||||||
|
0
|
||||||
|
) if response_meta else 0
|
||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
|
|||||||
@@ -97,6 +97,7 @@ class Settings:
|
|||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
@@ -230,8 +231,8 @@ class Settings:
|
|||||||
# Celery configuration (internal)
|
# Celery configuration (internal)
|
||||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||||
# 详见 docs/celery-env-bug-report.md
|
# 详见 docs/celery-env-bug-report.md
|
||||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
# 默认使用 Redis 作为 broker 和 backend,与业务缓存隔离
|
||||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
# 如需使用 RabbitMQ,在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost
|
||||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||||
|
|
||||||
|
|||||||
@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
|
|||||||
# Fallback to console only if file write fails
|
# Fallback to console only if file write fails
|
||||||
print(f"Warning: Could not write to timing log: {e}")
|
print(f"Warning: Could not write to timing log: {e}")
|
||||||
|
|
||||||
# Always print to console (backward compatible behavior)
|
# Always log at INFO level (avoids Celery treating stdout as WARNING)
|
||||||
print(f"✓ {step_name}: {duration:.2f}s")
|
_timing_logger = logging.getLogger(__name__)
|
||||||
|
_timing_logger.info(f"✓ {step_name}: {duration:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
def get_agent_logger(name: str = "agent_service",
|
def get_agent_logger(name: str = "agent_service",
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = (redis_messages)
|
formatted_messages = redis_messages
|
||||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
|
|||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
messages: list = None,
|
messages: list = None,
|
||||||
ref_id: str = "wyl_20251027",
|
ref_id: str = "",
|
||||||
config_id: str = None
|
config_id: str = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||||
@@ -40,12 +40,13 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
role = msg['role']
|
role = msg['role']
|
||||||
content = msg['content']
|
content = msg['content']
|
||||||
|
files = msg.get("file_content", [])
|
||||||
|
|
||||||
if role not in ['user', 'assistant']:
|
if role not in ['user', 'assistant']:
|
||||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||||
|
|
||||||
if content.strip():
|
if content.strip():
|
||||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files))
|
||||||
|
|
||||||
if not conversation_messages:
|
if not conversation_messages:
|
||||||
raise ValueError("Message list cannot be empty after filtering")
|
raise ValueError("Message list cannot be empty after filtering")
|
||||||
|
|||||||
@@ -39,6 +39,30 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
## 指代消歧规则(Coreference Resolution):
|
||||||
|
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||||
|
|
||||||
|
1. **"用户"的消歧**:
|
||||||
|
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||||
|
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
|
||||||
|
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||||
|
|
||||||
|
2. **"我"的消歧**:
|
||||||
|
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||||
|
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||||
|
|
||||||
|
3. **"他/她/它"的消歧**:
|
||||||
|
- 从上下文或历史中找出最近提到的同类实体
|
||||||
|
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||||
|
|
||||||
|
4. **"那个人/这个人"的消歧**:
|
||||||
|
- 从历史中找出最近提到的人物
|
||||||
|
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||||
|
|
||||||
|
5. **优先级**:
|
||||||
|
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||||
|
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
输出要求:
|
输出要求:
|
||||||
@@ -71,6 +95,34 @@
|
|||||||
"reason": "输出原问题的关键要素"
|
"reason": "输出原问题的关键要素"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
## 指代消歧示例(重要):
|
||||||
|
示例1 - "用户"的消歧:
|
||||||
|
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||||
|
输入问题:"用户是谁?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"original_question": "用户是谁?",
|
||||||
|
"extended_question": "李建国是谁?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
示例2 - "我"的消歧:
|
||||||
|
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||||
|
输入问题:"我推荐的书是什么?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"original_question": "我推荐的书是什么?",
|
||||||
|
"extended_question": "张曼玉推荐的书是什么?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
**Output format**
|
**Output format**
|
||||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||||
|
|||||||
@@ -27,6 +27,30 @@
|
|||||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||||
|
|
||||||
|
## 指代消歧规则(Coreference Resolution):
|
||||||
|
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
|
||||||
|
|
||||||
|
1. **"用户"的消歧**:
|
||||||
|
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
|
||||||
|
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
|
||||||
|
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
|
||||||
|
|
||||||
|
2. **"我"的消歧**:
|
||||||
|
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么?"
|
||||||
|
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
|
||||||
|
|
||||||
|
3. **"他/她/它"的消歧**:
|
||||||
|
- 从上下文或历史中找出最近提到的同类实体
|
||||||
|
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
|
||||||
|
|
||||||
|
4. **"那个人/这个人"的消歧**:
|
||||||
|
- 从历史中找出最近提到的人物
|
||||||
|
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
|
||||||
|
|
||||||
|
5. **优先级**:
|
||||||
|
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
|
||||||
|
- 如果无法从历史中确定指代对象,保留原问题,但在reason中说明"无法确定指代对象"
|
||||||
|
|
||||||
## 指令:
|
## 指令:
|
||||||
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
|
||||||
单跳(Single-hop)
|
单跳(Single-hop)
|
||||||
@@ -151,6 +175,34 @@
|
|||||||
]
|
]
|
||||||
- 必须通过json.loads()的格式支持的形式输出
|
- 必须通过json.loads()的格式支持的形式输出
|
||||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||||
|
|
||||||
|
## 指代消歧示例(重要):
|
||||||
|
示例1 - "用户"的消歧:
|
||||||
|
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
|
||||||
|
输入问题:"用户是谁?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": "李建国是谁?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中反复提到'老李/李建国/建国哥','用户'指的就是对话发起者李建国"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
示例2 - "我"的消歧:
|
||||||
|
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
|
||||||
|
输入问题:"我推荐的书是什么?"
|
||||||
|
输出:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": "Q1",
|
||||||
|
"question": "张曼玉推荐的书是什么?",
|
||||||
|
"type": "单跳",
|
||||||
|
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
- 关键的JSON格式要求
|
- 关键的JSON格式要求
|
||||||
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
1.JSON结构仅使用标准ASCII双引号(“)-切勿使用中文引号(“”)或其他Unicode引号
|
||||||
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们
|
||||||
|
|||||||
@@ -6,35 +6,37 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \
|
||||||
|
memory_summary_generation
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.core.memory.utils.log.logging_utils import log_time
|
from app.core.memory.utils.log.logging_utils import log_time
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
messages: list,
|
messages: list,
|
||||||
ref_id: str = "wyl20251027",
|
ref_id: str = "",
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Execute the complete knowledge extraction pipeline.
|
Execute the complete knowledge extraction pipeline.
|
||||||
@@ -43,9 +45,11 @@ async def write(
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||||
ref_id: Reference ID, defaults to "wyl20251027"
|
ref_id: Reference ID, defaults to ""
|
||||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||||
"""
|
"""
|
||||||
|
if not ref_id:
|
||||||
|
ref_id = uuid.uuid4().hex
|
||||||
# Extract config values
|
# Extract config values
|
||||||
embedding_model_id = str(memory_config.embedding_model_id)
|
embedding_model_id = str(memory_config.embedding_model_id)
|
||||||
chunker_strategy = memory_config.chunker_strategy
|
chunker_strategy = memory_config.chunker_strategy
|
||||||
@@ -135,9 +139,11 @@ async def write(
|
|||||||
all_chunk_nodes,
|
all_chunk_nodes,
|
||||||
all_statement_nodes,
|
all_statement_nodes,
|
||||||
all_entity_nodes,
|
all_entity_nodes,
|
||||||
|
all_perceptual_nodes,
|
||||||
all_statement_chunk_edges,
|
all_statement_chunk_edges,
|
||||||
all_statement_entity_edges,
|
all_statement_entity_edges,
|
||||||
all_entity_entity_edges,
|
all_entity_entity_edges,
|
||||||
|
all_perceptual_edges,
|
||||||
all_dedup_details,
|
all_dedup_details,
|
||||||
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
|
||||||
|
|
||||||
@@ -145,11 +151,6 @@ async def write(
|
|||||||
|
|
||||||
# Step 3: Save all data to Neo4j database
|
# Step 3: Save all data to Neo4j database
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
|
||||||
try:
|
|
||||||
await create_fulltext_indexes()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# 添加死锁重试机制
|
# 添加死锁重试机制
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
@@ -162,19 +163,43 @@ async def write(
|
|||||||
chunk_nodes=all_chunk_nodes,
|
chunk_nodes=all_chunk_nodes,
|
||||||
statement_nodes=all_statement_nodes,
|
statement_nodes=all_statement_nodes,
|
||||||
entity_nodes=all_entity_nodes,
|
entity_nodes=all_entity_nodes,
|
||||||
|
perceptual_nodes=all_perceptual_nodes,
|
||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
|
perceptual_edges=all_perceptual_edges,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
|
||||||
schedule_clustering_after_write(
|
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||||
all_entity_nodes,
|
if all_entity_nodes:
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
try:
|
||||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
from app.tasks import run_incremental_clustering
|
||||||
)
|
|
||||||
|
end_user_id = all_entity_nodes[0].end_user_id
|
||||||
|
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||||
|
|
||||||
|
# 异步提交 Celery 任务
|
||||||
|
task = run_incremental_clustering.apply_async(
|
||||||
|
kwargs={
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"new_entity_ids": new_entity_ids,
|
||||||
|
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
|
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
|
},
|
||||||
|
# 设置任务优先级(低优先级,不影响主业务)
|
||||||
|
priority=3,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||||
|
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# 聚类任务提交失败不影响主流程
|
||||||
|
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||||
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
@@ -208,9 +233,8 @@ async def write(
|
|||||||
summaries = await memory_summary_generation(
|
summaries = await memory_summary_generation(
|
||||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||||
)
|
)
|
||||||
|
ms_connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
ms_connector = Neo4jConnector()
|
|
||||||
await add_memory_summary_nodes(summaries, ms_connector)
|
await add_memory_summary_nodes(summaries, ms_connector)
|
||||||
await add_memory_summary_statement_edges(summaries, ms_connector)
|
await add_memory_summary_statement_edges(summaries, ms_connector)
|
||||||
finally:
|
finally:
|
||||||
@@ -250,5 +274,21 @@ async def write(
|
|||||||
except Exception as cache_err:
|
except Exception as cache_err:
|
||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
|
# Close LLM/Embedder underlying httpx clients to prevent
|
||||||
|
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||||
|
for client_obj in (llm_client, embedder_client):
|
||||||
|
try:
|
||||||
|
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||||
|
if underlying is None:
|
||||||
|
continue
|
||||||
|
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||||
|
inner = getattr(underlying, '_model', underlying)
|
||||||
|
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||||
|
http_client = getattr(inner, 'async_client', None)
|
||||||
|
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||||
|
await http_client.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
from typing import Any, List
|
|
||||||
import re
|
|
||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Fix tokenizer parallelism warning
|
# Fix tokenizer parallelism warning
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
@@ -246,6 +246,7 @@ class ChunkerClient:
|
|||||||
"total_sub_chunks": len(sub_chunks),
|
"total_sub_chunks": len(sub_chunks),
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
else:
|
else:
|
||||||
@@ -258,6 +259,7 @@ class ChunkerClient:
|
|||||||
"message_role": msg.role,
|
"message_role": msg.role,
|
||||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||||
},
|
},
|
||||||
|
files=msg.files
|
||||||
)
|
)
|
||||||
dialogue.chunks.append(chunk)
|
dialogue.chunks.append(chunk)
|
||||||
|
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
|||||||
type=type_
|
type=type_
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||||
|
|
||||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
OpenAI Embedder 客户端实现
|
OpenAI Embedder 客户端实现
|
||||||
|
|
||||||
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
|
||||||
|
自动支持火山引擎的多模态 Embedding。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
@@ -13,6 +14,7 @@ from app.core.memory.llm_tools.embedder_client import (
|
|||||||
)
|
)
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
from app.core.models.embedding import RedBearEmbeddings
|
from app.core.models.embedding import RedBearEmbeddings
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -25,6 +27,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
- 批量文本嵌入
|
- 批量文本嵌入
|
||||||
- 自动重试机制
|
- 自动重试机制
|
||||||
- 错误处理
|
- 错误处理
|
||||||
|
- 火山引擎多模态 Embedding(自动识别)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_config: RedBearModelConfig):
|
def __init__(self, model_config: RedBearModelConfig):
|
||||||
@@ -36,7 +39,7 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
"""
|
"""
|
||||||
super().__init__(model_config)
|
super().__init__(model_config)
|
||||||
|
|
||||||
# 初始化 RedBearEmbeddings 模型
|
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||||
self.model = RedBearEmbeddings(
|
self.model = RedBearEmbeddings(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=self.model_name,
|
model_name=self.model_name,
|
||||||
@@ -47,8 +50,9 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
timeout=self.timeout,
|
timeout=self.timeout,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self.is_multimodal = self.model.is_multimodal_supported()
|
||||||
|
|
||||||
logger.info("OpenAI Embedder 客户端初始化完成")
|
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})")
|
||||||
|
|
||||||
async def response(
|
async def response(
|
||||||
self,
|
self,
|
||||||
@@ -77,7 +81,14 @@ class OpenAIEmbedderClient(EmbedderClient):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 生成嵌入向量
|
# 生成嵌入向量
|
||||||
embeddings = await self.model.aembed_documents(texts)
|
if self.is_multimodal:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = await self.model.aembed_multimodal(
|
||||||
|
[{"type": "text", "text": text} for text in texts]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 普通 Embedding
|
||||||
|
embeddings = await self.model.aembed_documents(texts)
|
||||||
|
|
||||||
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ class Edge(BaseModel):
|
|||||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.")
|
||||||
|
|
||||||
|
|
||||||
class ChunkEdge(Edge):
|
class ChunkEdge(Edge):
|
||||||
@@ -175,6 +175,12 @@ class EntityEntityEdge(Edge):
|
|||||||
return parse_historical_datetime(v)
|
return parse_historical_datetime(v)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualEdge(Edge):
|
||||||
|
"""Edge connecting perceptual nodes to their source chunks
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Node(BaseModel):
|
class Node(BaseModel):
|
||||||
"""Base class for all graph nodes in the knowledge graph.
|
"""Base class for all graph nodes in the knowledge graph.
|
||||||
|
|
||||||
@@ -206,7 +212,8 @@ class DialogueNode(Node):
|
|||||||
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
ref_id: str = Field(..., description="Reference identifier of the dialog")
|
||||||
content: str = Field(..., description="Dialogue content")
|
content: str = Field(..., description="Dialogue content")
|
||||||
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this dialogue (integer or string)")
|
||||||
|
|
||||||
|
|
||||||
class StatementNode(Node):
|
class StatementNode(Node):
|
||||||
@@ -281,7 +288,8 @@ class StatementNode(Node):
|
|||||||
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this statement (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -416,7 +424,8 @@ class ExtractedEntityNode(Node):
|
|||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this entity (integer or string)")
|
||||||
|
|
||||||
# ACT-R Memory Activation Properties
|
# ACT-R Memory Activation Properties
|
||||||
importance_score: float = Field(
|
importance_score: float = Field(
|
||||||
@@ -453,7 +462,7 @@ class ExtractedEntityNode(Node):
|
|||||||
|
|
||||||
@field_validator('aliases', mode='before')
|
@field_validator('aliases', mode='before')
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||||
"""Validate and clean aliases field using utility function.
|
"""Validate and clean aliases field using utility function.
|
||||||
|
|
||||||
This validator ensures that the aliases field is always a valid list of strings.
|
This validator ensures that the aliases field is always a valid list of strings.
|
||||||
@@ -507,7 +516,8 @@ class MemorySummaryNode(Node):
|
|||||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
config_id: Optional[int | str] = Field(None,
|
||||||
|
description="Configuration ID used to process this summary (integer or string)")
|
||||||
|
|
||||||
# ACT-R Forgetting Engine Properties
|
# ACT-R Forgetting Engine Properties
|
||||||
original_statement_id: Optional[str] = Field(
|
original_statement_id: Optional[str] = Field(
|
||||||
@@ -549,3 +559,18 @@ class MemorySummaryNode(Node):
|
|||||||
ge=0,
|
ge=0,
|
||||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualNode(Node):
|
||||||
|
"""Node representing a multimodal message in the knowledge graph.
|
||||||
|
"""
|
||||||
|
perceptual_type: int
|
||||||
|
file_path: str
|
||||||
|
file_name: str
|
||||||
|
file_ext: str
|
||||||
|
summary: str
|
||||||
|
keywords: list[str]
|
||||||
|
topic: str
|
||||||
|
domain: str
|
||||||
|
file_type: str
|
||||||
|
summary_embedding: list[float] | None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ class ConversationMessage(BaseModel):
|
|||||||
"""
|
"""
|
||||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||||
msg: str = Field(..., description="The text content of the message.")
|
msg: str = Field(..., description="The text content of the message.")
|
||||||
|
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
|
||||||
|
|
||||||
|
|
||||||
class TemporalValidityRange(BaseModel):
|
class TemporalValidityRange(BaseModel):
|
||||||
@@ -130,7 +131,8 @@ class Chunk(BaseModel):
|
|||||||
content: str = Field(..., description="The content of the chunk as a string.")
|
content: str = Field(..., description="The content of the chunk as a string.")
|
||||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.")
|
||||||
|
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
|
||||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ class LabelPropagationEngine:
|
|||||||
self.repo = CommunityRepository(connector)
|
self.repo = CommunityRepository(connector)
|
||||||
self.llm_model_id = llm_model_id
|
self.llm_model_id = llm_model_id
|
||||||
self.embedding_model_id = embedding_model_id
|
self.embedding_model_id = embedding_model_id
|
||||||
|
# 缓存客户端实例,避免重复初始化
|
||||||
|
self._llm_client = None
|
||||||
|
self._embedder_client = None
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -215,8 +218,17 @@ class LabelPropagationEngine:
|
|||||||
3. 若邻居无社区 → 创建新社区
|
3. 若邻居无社区 → 创建新社区
|
||||||
4. 若邻居分属多个社区 → 评估是否合并
|
4. 若邻居分属多个社区 → 评估是否合并
|
||||||
"""
|
"""
|
||||||
|
# 收集所有需要生成元数据的社区ID
|
||||||
|
communities_to_update = set()
|
||||||
|
|
||||||
for entity_id in new_entity_ids:
|
for entity_id in new_entity_ids:
|
||||||
await self._process_single_entity(entity_id, end_user_id)
|
cid = await self._process_single_entity(entity_id, end_user_id)
|
||||||
|
if cid:
|
||||||
|
communities_to_update.add(cid)
|
||||||
|
|
||||||
|
# 批量生成所有社区的元数据
|
||||||
|
if communities_to_update:
|
||||||
|
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 内部方法
|
# 内部方法
|
||||||
@@ -224,8 +236,21 @@ class LabelPropagationEngine:
|
|||||||
|
|
||||||
async def _process_single_entity(
|
async def _process_single_entity(
|
||||||
self, entity_id: str, end_user_id: str
|
self, entity_id: str, end_user_id: str
|
||||||
) -> None:
|
) -> Optional[str]:
|
||||||
"""处理单个新实体的社区分配。"""
|
"""
|
||||||
|
处理单个新实体的社区分配。
|
||||||
|
|
||||||
|
该函数会为新实体分配社区,可能的情况包括:
|
||||||
|
1. 孤立实体(无邻居):创建新的单成员社区
|
||||||
|
2. 邻居都没有社区:创建新社区并将实体和邻居都加入
|
||||||
|
3. 邻居有社区:通过加权投票选择最合适的社区加入
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[str]: 分配到的社区ID。当前实现总是返回一个有效的社区ID,
|
||||||
|
但返回类型保留为Optional以支持未来可能的扩展场景
|
||||||
|
(例如:实体无法分配到任何社区的情况)。
|
||||||
|
调用方应检查返回值的真假性(truthiness)。
|
||||||
|
"""
|
||||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||||
|
|
||||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||||
@@ -237,7 +262,7 @@ class LabelPropagationEngine:
|
|||||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||||
return
|
return new_cid
|
||||||
|
|
||||||
# 统计邻居社区分布
|
# 统计邻居社区分布
|
||||||
community_ids_in_neighbors = set(
|
community_ids_in_neighbors = set(
|
||||||
@@ -259,7 +284,7 @@ class LabelPropagationEngine:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata([new_cid], end_user_id)
|
return new_cid
|
||||||
else:
|
else:
|
||||||
# 加入得票最多的社区
|
# 加入得票最多的社区
|
||||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||||
@@ -271,7 +296,8 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata([target_cid], end_user_id)
|
# 返回目标社区ID,稍后批量生成元数据
|
||||||
|
return target_cid
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -451,80 +477,84 @@ class LabelPropagationEngine:
|
|||||||
return lines
|
return lines
|
||||||
|
|
||||||
async def _generate_community_metadata(
|
async def _generate_community_metadata(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为一个或多个社区生成并写入元数据。
|
为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
1. 批量准备所有社区的 prompt
|
||||||
2. 收集所有 summary,一次性批量 embed
|
2. 并发调用 LLM 生成所有社区的 name / summary
|
||||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
3. 批量 embed 所有 summary
|
||||||
|
4. 批量写入数据库
|
||||||
|
|
||||||
|
Args:
|
||||||
|
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||||
"""
|
"""
|
||||||
if not community_ids:
|
async def _prepare_one(cid: str) -> Optional[Dict]:
|
||||||
return
|
"""准备单个社区的数据和 prompt"""
|
||||||
|
try:
|
||||||
|
if not force:
|
||||||
|
check_embedding = bool(self.embedding_model_id)
|
||||||
|
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
|
||||||
|
return None
|
||||||
|
|
||||||
from app.db import get_db_context
|
members = await self.repo.get_community_members(cid, end_user_id)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
if not members:
|
||||||
|
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
|
||||||
|
return None
|
||||||
|
|
||||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
sorted_members = sorted(
|
||||||
async def _build_one(cid: str):
|
members,
|
||||||
members = await self.repo.get_community_members(cid, end_user_id)
|
key=lambda m: m.get("activation_value") or 0,
|
||||||
if not members:
|
reverse=True,
|
||||||
|
)
|
||||||
|
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||||
|
all_names = [m["name"] for m in members if m.get("name")]
|
||||||
|
|
||||||
|
# 默认值
|
||||||
|
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||||
|
summary = f"包含实体:{', '.join(all_names)}"
|
||||||
|
|
||||||
|
# 准备 LLM prompt(如果配置了 LLM)
|
||||||
|
prompt = None
|
||||||
|
if self.llm_model_id:
|
||||||
|
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||||
|
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||||
|
rel_lines = [
|
||||||
|
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||||
|
for r in relationships
|
||||||
|
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||||
|
]
|
||||||
|
rel_section = (
|
||||||
|
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||||
|
if rel_lines else ""
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||||
|
f"请为这组实体所代表的主题:\n"
|
||||||
|
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||||
|
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||||
|
f"严格按以下格式输出,不要有其他内容:\n"
|
||||||
|
f"名称:<名称>\n摘要:<摘要>"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"community_id": cid,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"name": name,
|
||||||
|
"summary": summary,
|
||||||
|
"core_entities": core_entities,
|
||||||
|
"prompt": prompt,
|
||||||
|
"summary_embedding": None,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
sorted_members = sorted(
|
# --- 阶段1:并发准备所有社区数据 ---
|
||||||
members,
|
|
||||||
key=lambda m: m.get("activation_value") or 0,
|
|
||||||
reverse=True,
|
|
||||||
)
|
|
||||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
|
||||||
|
|
||||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
|
||||||
|
|
||||||
# 方案四:注入社区内实体间关系三元组
|
|
||||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
|
||||||
rel_lines = [
|
|
||||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
|
||||||
for r in relationships
|
|
||||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
|
||||||
]
|
|
||||||
rel_section = (
|
|
||||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
|
||||||
if rel_lines else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt = (
|
|
||||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
|
||||||
f"请为这组实体所代表的主题:\n"
|
|
||||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
|
||||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
|
||||||
f"严格按以下格式输出,不要有其他内容:\n"
|
|
||||||
f"名称:<名称>\n摘要:<摘要>"
|
|
||||||
)
|
|
||||||
with get_db_context() as db:
|
|
||||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
|
||||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
|
|
||||||
name, summary = "", ""
|
|
||||||
for line in text.strip().splitlines():
|
|
||||||
if line.startswith("名称:"):
|
|
||||||
name = line[3:].strip()
|
|
||||||
elif line.startswith("摘要:"):
|
|
||||||
summary = line[3:].strip()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"community_id": cid,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"name": name,
|
|
||||||
"summary": summary,
|
|
||||||
"core_entities": core_entities,
|
|
||||||
"summary_embedding": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[_build_one(cid) for cid in community_ids],
|
*[_prepare_one(cid) for cid in community_ids],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
metadata_list = []
|
metadata_list = []
|
||||||
@@ -535,17 +565,80 @@ class LabelPropagationEngine:
|
|||||||
metadata_list.append(res)
|
metadata_list.append(res)
|
||||||
|
|
||||||
if not metadata_list:
|
if not metadata_list:
|
||||||
|
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- 阶段2:批量生成 summary_embedding ---
|
# --- 阶段2:批量调用 LLM 生成 name 和 summary ---
|
||||||
summaries = [m["summary"] for m in metadata_list]
|
if self.llm_model_id:
|
||||||
with get_db_context() as db:
|
llm_client = self._get_llm_client()
|
||||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
if not llm_client:
|
||||||
embeddings = await embedder.response(summaries)
|
logger.warning(
|
||||||
for i, meta in enumerate(metadata_list):
|
f"[Clustering] LLM 已配置(model_id={self.llm_model_id})但客户端初始化失败,"
|
||||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||||
|
)
|
||||||
|
if llm_client:
|
||||||
|
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
|
||||||
|
|
||||||
|
if prompts_to_process:
|
||||||
|
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
|
||||||
|
|
||||||
|
async def _call_llm(idx: int, meta: Dict) -> tuple:
|
||||||
|
"""单个 LLM 调用"""
|
||||||
|
try:
|
||||||
|
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
return (idx, text, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
|
||||||
|
return (idx, None, e)
|
||||||
|
|
||||||
|
# 并发调用所有 LLM 请求
|
||||||
|
llm_results = await asyncio.gather(
|
||||||
|
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析 LLM 响应
|
||||||
|
for result in llm_results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
continue
|
||||||
|
idx, text, error = result
|
||||||
|
if error or not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
meta = metadata_list[idx]
|
||||||
|
for line in text.strip().splitlines():
|
||||||
|
if line.startswith("名称:"):
|
||||||
|
meta["name"] = line[3:].strip()
|
||||||
|
elif line.startswith("摘要:"):
|
||||||
|
meta["summary"] = line[3:].strip()
|
||||||
|
|
||||||
|
logger.info(f"[Clustering] LLM 批量生成完成")
|
||||||
|
|
||||||
|
# --- 阶段3:批量生成 summary_embedding ---
|
||||||
|
if self.embedding_model_id:
|
||||||
|
embedder = self._get_embedder_client()
|
||||||
|
if not embedder:
|
||||||
|
logger.warning(
|
||||||
|
f"[Clustering] Embedding 已配置(model_id={self.embedding_model_id})但客户端初始化失败,"
|
||||||
|
f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。"
|
||||||
|
)
|
||||||
|
if embedder:
|
||||||
|
try:
|
||||||
|
summaries = [m["summary"] for m in metadata_list]
|
||||||
|
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
|
||||||
|
embeddings = await embedder.response(summaries)
|
||||||
|
for i, meta in enumerate(metadata_list):
|
||||||
|
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||||
|
logger.info(f"[Clustering] Embedding 批量生成完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# --- 阶段4:批量写入数据库 ---
|
||||||
|
# 移除 prompt 字段(不需要存储)
|
||||||
|
for m in metadata_list:
|
||||||
|
m.pop("prompt", None)
|
||||||
|
|
||||||
# --- 阶段3:写入(单个 or 批量)---
|
|
||||||
if len(metadata_list) == 1:
|
if len(metadata_list) == 1:
|
||||||
m = metadata_list[0]
|
m = metadata_list[0]
|
||||||
result = await self.repo.update_community_metadata(
|
result = await self.repo.update_community_metadata(
|
||||||
@@ -556,16 +649,34 @@ class LabelPropagationEngine:
|
|||||||
core_entities=m["core_entities"],
|
core_entities=m["core_entities"],
|
||||||
summary_embedding=m["summary_embedding"],
|
summary_embedding=m["summary_embedding"],
|
||||||
)
|
)
|
||||||
if result:
|
if not result:
|
||||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
|
||||||
else:
|
|
||||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
|
||||||
else:
|
else:
|
||||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||||
if ok:
|
if not ok:
|
||||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||||
|
|
||||||
|
def _get_llm_client(self):
|
||||||
|
"""获取或创建 LLM 客户端(单例模式)"""
|
||||||
|
if self._llm_client is None and self.llm_model_id:
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
with get_db_context() as db:
|
||||||
|
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||||
|
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
|
||||||
|
return self._llm_client
|
||||||
|
|
||||||
|
def _get_embedder_client(self):
|
||||||
|
"""获取或创建 Embedder 客户端(单例模式)"""
|
||||||
|
if self._embedder_client is None and self.embedding_model_id:
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
with get_db_context() as db:
|
||||||
|
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||||
|
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
|
||||||
|
return self._embedder_client
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
|
|||||||
@@ -9,6 +9,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
@@ -26,6 +27,20 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
|||||||
ScenePatterns
|
ScenePatterns
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def message_has_files(message: "ConversationMessage") -> bool:
|
||||||
|
"""检查消息是否包含文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 待检查的消息对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||||
|
"""
|
||||||
|
return message.files and len(message.files) > 0
|
||||||
|
|
||||||
|
|
||||||
class DialogExtractionResponse(BaseModel):
|
class DialogExtractionResponse(BaseModel):
|
||||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||||
@@ -125,7 +140,7 @@ class SemanticPruner:
|
|||||||
1. 空消息
|
1. 空消息
|
||||||
2. 场景特定填充词库精确匹配
|
2. 场景特定填充词库精确匹配
|
||||||
3. 常见寒暄精确匹配
|
3. 常见寒暄精确匹配
|
||||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||||
5. 纯表情/标点
|
5. 纯表情/标点
|
||||||
"""
|
"""
|
||||||
t = message.msg.strip()
|
t = message.msg.strip()
|
||||||
@@ -479,6 +494,11 @@ class SemanticPruner:
|
|||||||
"""
|
"""
|
||||||
to_delete_ids: set = set()
|
to_delete_ids: set = set()
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||||
if self._is_filler_message(m):
|
if self._is_filler_message(m):
|
||||||
to_delete_ids.add(id(m))
|
to_delete_ids.add(id(m))
|
||||||
@@ -547,6 +567,11 @@ class SemanticPruner:
|
|||||||
for m in msgs:
|
for m in msgs:
|
||||||
msg_text = m.msg.strip()
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||||
if self._is_filler_message(m):
|
if self._is_filler_message(m):
|
||||||
to_delete_ids.add(id(m))
|
to_delete_ids.add(id(m))
|
||||||
@@ -706,7 +731,7 @@ class SemanticPruner:
|
|||||||
# 阈值保护:最高0.9
|
# 阈值保护:最高0.9
|
||||||
proportion = float(self.config.pruning_threshold)
|
proportion = float(self.config.pruning_threshold)
|
||||||
if proportion > 0.9:
|
if proportion > 0.9:
|
||||||
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9,已自动调整为0.9")
|
||||||
proportion = 0.9
|
proportion = 0.9
|
||||||
if proportion < 0.0:
|
if proportion < 0.0:
|
||||||
proportion = 0.0
|
proportion = 0.0
|
||||||
@@ -799,6 +824,12 @@ class SemanticPruner:
|
|||||||
for idx, m in enumerate(msgs):
|
for idx, m in enumerate(msgs):
|
||||||
msg_text = m.msg.strip()
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||||
|
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||||
|
continue
|
||||||
|
|
||||||
if self._msg_matches_tokens(m, preserve_tokens):
|
if self._msg_matches_tokens(m, preserve_tokens):
|
||||||
llm_protected_msgs.append((idx, m))
|
llm_protected_msgs.append((idx, m))
|
||||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||||
@@ -905,7 +936,7 @@ class SemanticPruner:
|
|||||||
|
|
||||||
# Safety: avoid empty dataset
|
# Safety: avoid empty dataset
|
||||||
if not result:
|
if not result:
|
||||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||||
return dialogs
|
return dialogs
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -915,8 +946,7 @@ class SemanticPruner:
|
|||||||
try:
|
try:
|
||||||
self.run_logs.append(msg)
|
self.run_logs.append(msg)
|
||||||
except Exception:
|
except Exception:
|
||||||
# 任何异常都不影响打印
|
|
||||||
pass
|
pass
|
||||||
print(msg)
|
logger.debug(msg)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -203,6 +203,7 @@ def accurate_match(
|
|||||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||||
"""
|
"""
|
||||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||||
|
同时检测某实体的 name 是否命中另一实体的 aliases,若命中则直接合并。
|
||||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||||
"""
|
"""
|
||||||
exact_merge_map: Dict[str, Dict] = {}
|
exact_merge_map: Dict[str, Dict] = {}
|
||||||
@@ -240,6 +241,48 @@ def accurate_match(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
deduped_entities = list(canonical_map.values())
|
deduped_entities = list(canonical_map.values())
|
||||||
|
|
||||||
|
# 2) 第二轮:检测某实体的 name 是否命中另一实体的 aliases(alias-to-name 精确合并)
|
||||||
|
# 场景:LLM 把 aliases 中的词(如"齐齐")又单独抽取为独立实体,需在此阶段合并掉
|
||||||
|
# 优化:先构建 (end_user_id, alias_lower) -> canonical 的反向索引,查找 O(1)
|
||||||
|
alias_index: Dict[tuple, ExtractedEntityNode] = {}
|
||||||
|
for canonical in deduped_entities:
|
||||||
|
uid = getattr(canonical, "end_user_id", None)
|
||||||
|
for alias in (getattr(canonical, "aliases", []) or []):
|
||||||
|
alias_lower = alias.strip().lower()
|
||||||
|
if alias_lower:
|
||||||
|
alias_index[(uid, alias_lower)] = canonical
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
while i < len(deduped_entities):
|
||||||
|
ent = deduped_entities[i]
|
||||||
|
ent_name = (getattr(ent, "name", "") or "").strip().lower()
|
||||||
|
ent_uid = getattr(ent, "end_user_id", None)
|
||||||
|
canonical = alias_index.get((ent_uid, ent_name))
|
||||||
|
# 确保不是自身
|
||||||
|
if canonical is not None and canonical.id != ent.id:
|
||||||
|
_merge_attribute(canonical, ent)
|
||||||
|
id_redirect[ent.id] = canonical.id
|
||||||
|
for k, v in list(id_redirect.items()):
|
||||||
|
if v == ent.id:
|
||||||
|
id_redirect[k] = canonical.id
|
||||||
|
try:
|
||||||
|
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||||
|
if k not in exact_merge_map:
|
||||||
|
exact_merge_map[k] = {
|
||||||
|
"canonical_id": canonical.id,
|
||||||
|
"end_user_id": canonical.end_user_id,
|
||||||
|
"name": canonical.name,
|
||||||
|
"entity_type": canonical.entity_type,
|
||||||
|
"merged_ids": set(),
|
||||||
|
}
|
||||||
|
exact_merge_map[k]["merged_ids"].add(ent.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
deduped_entities.pop(i)
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
return deduped_entities, id_redirect, exact_merge_map
|
return deduped_entities, id_redirect, exact_merge_map
|
||||||
|
|
||||||
def fuzzy_match(
|
def fuzzy_match(
|
||||||
|
|||||||
@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
|||||||
|
|
||||||
|
|
||||||
async def dedup_layers_and_merge_and_return(
|
async def dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
pipeline_config: ExtractionPipelineConfig,
|
pipeline_config: ExtractionPipelineConfig,
|
||||||
connector: Optional[Neo4jConnector] = None,
|
connector: Optional[Neo4jConnector] = None,
|
||||||
llm_client = None,
|
llm_client=None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
|
|||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
dict, # 新增:返回去重详情
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两层实体去重与融合:
|
执行两层实体去重与融合:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@@ -32,10 +33,11 @@ from app.core.memory.models.graph_models import (
|
|||||||
StatementChunkEdge,
|
StatementChunkEdge,
|
||||||
StatementEntityEdge,
|
StatementEntityEdge,
|
||||||
StatementNode,
|
StatementNode,
|
||||||
|
PerceptualEdge,
|
||||||
|
PerceptualNode
|
||||||
)
|
)
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
|
||||||
from app.core.memory.models.variate_config import (
|
from app.core.memory.models.variate_config import (
|
||||||
ExtractionPipelineConfig,
|
ExtractionPipelineConfig,
|
||||||
)
|
)
|
||||||
@@ -46,7 +48,6 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.emb
|
|||||||
embedding_generation,
|
embedding_generation,
|
||||||
generate_entity_embeddings_from_triplets,
|
generate_entity_embeddings_from_triplets,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 导入各个提取模块
|
# 导入各个提取模块
|
||||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.statement_extraction import (
|
||||||
StatementExtractor,
|
StatementExtractor,
|
||||||
@@ -62,6 +63,10 @@ from app.core.memory.storage_services.extraction_engine.pipeline_help import (
|
|||||||
export_test_input_doc,
|
export_test_input_doc,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.data.ontology import TemporalInfo
|
from app.core.memory.utils.data.ontology import TemporalInfo
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.models.end_user_info_model import EndUserInfo
|
||||||
|
from app.repositories.end_user_info_repository import EndUserInfoRepository
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
@@ -90,16 +95,16 @@ class ExtractionOrchestrator:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_client: LLMClient,
|
llm_client: LLMClient,
|
||||||
embedder_client: OpenAIEmbedderClient,
|
embedder_client: OpenAIEmbedderClient,
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config: Optional[ExtractionPipelineConfig] = None,
|
config: Optional[ExtractionPipelineConfig] = None,
|
||||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||||
embedding_id: Optional[str] = None,
|
embedding_id: Optional[str] = None,
|
||||||
ontology_types: Optional[OntologyTypeList] = None,
|
ontology_types: Optional[OntologyTypeList] = None,
|
||||||
enable_general_types: bool = True,
|
enable_general_types: bool = True,
|
||||||
language: str = "zh",
|
language: str = "zh",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化流水线编排器
|
初始化流水线编排器
|
||||||
@@ -157,19 +162,27 @@ class ExtractionOrchestrator:
|
|||||||
llm_client=llm_client,
|
llm_client=llm_client,
|
||||||
config=self.config.statement_extraction,
|
config=self.config.statement_extraction,
|
||||||
)
|
)
|
||||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
self.triplet_extractor = TripletExtractor(llm_client=llm_client, ontology_types=self.ontology_types,
|
||||||
|
language=language)
|
||||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||||
|
|
||||||
logger.info("ExtractionOrchestrator 初始化完成")
|
logger.info("ExtractionOrchestrator 初始化完成")
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
is_pilot_run: bool = False,
|
is_pilot_run: bool = False,
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[PerceptualNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
list[PerceptualEdge],
|
||||||
|
list[DialogData]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
运行完整的知识提取流水线(优化版:并行执行)
|
运行完整的知识提取流水线(优化版:并行执行)
|
||||||
@@ -208,7 +221,6 @@ class ExtractionOrchestrator:
|
|||||||
for dialog in dialog_data_list:
|
for dialog in dialog_data_list:
|
||||||
for chunk in dialog.chunks:
|
for chunk in dialog.chunks:
|
||||||
all_statements_list.extend(chunk.statements)
|
all_statements_list.extend(chunk.statements)
|
||||||
len(all_statements_list)
|
|
||||||
|
|
||||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||||
@@ -230,10 +242,6 @@ class ExtractionOrchestrator:
|
|||||||
all_entities_list.extend(triplet_info.entities)
|
all_entities_list.extend(triplet_info.entities)
|
||||||
all_triplets_list.extend(triplet_info.triplets)
|
all_triplets_list.extend(triplet_info.triplets)
|
||||||
|
|
||||||
len(all_entities_list)
|
|
||||||
len(all_triplets_list)
|
|
||||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
|
||||||
|
|
||||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||||
logger.info("步骤 3/6: 生成实体嵌入")
|
logger.info("步骤 3/6: 生成实体嵌入")
|
||||||
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
triplet_maps = await self._generate_entity_embeddings(triplet_maps)
|
||||||
@@ -260,9 +268,11 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
) = await self._create_nodes_and_edges(dialog_data_list)
|
) = await self._create_nodes_and_edges(dialog_data_list)
|
||||||
|
|
||||||
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
# 导出去重前的测试输入文档(试运行和正式模式都需要,用于生成结果汇总)
|
||||||
@@ -276,7 +286,17 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
# 注意:deduplication 消息已在创建节点和边完成后立即发送
|
||||||
|
|
||||||
result = await self._run_dedup_and_write_summary(
|
(
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
|
) = await self._run_dedup_and_write_summary(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
@@ -287,17 +307,31 @@ class ExtractionOrchestrator:
|
|||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||||
|
if not is_pilot_run:
|
||||||
|
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||||
|
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return result
|
return (
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
statement_entity_edges,
|
||||||
|
entity_entity_edges,
|
||||||
|
perceptual_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
logger.error(f"知识提取流水线运行失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _extract_statements(
|
async def _extract_statements(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
从对话中提取陈述句(流式输出版本:边提取边发送进度)
|
||||||
@@ -395,7 +429,7 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _extract_triplets(
|
async def _extract_triplets(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
从对话中提取三元组(流式输出版本:边提取边发送进度)
|
||||||
@@ -478,7 +512,7 @@ class ExtractionOrchestrator:
|
|||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
async def _extract_temporal(
|
async def _extract_temporal(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
从对话中提取时间信息(流式输出版本:边提取边发送进度)
|
||||||
@@ -585,7 +619,7 @@ class ExtractionOrchestrator:
|
|||||||
return temporal_maps
|
return temporal_maps
|
||||||
|
|
||||||
async def _extract_emotions(
|
async def _extract_emotions(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||||
@@ -706,7 +740,7 @@ class ExtractionOrchestrator:
|
|||||||
return emotion_maps
|
return emotion_maps
|
||||||
|
|
||||||
async def _parallel_extract_and_embed(
|
async def _parallel_extract_and_embed(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
List[Dict[str, Any]],
|
List[Dict[str, Any]],
|
||||||
@@ -777,7 +811,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_basic_embeddings(
|
async def _generate_basic_embeddings(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
) -> Tuple[List[Dict[str, List[float]]], List[Dict[str, List[float]]], List[List[float]]]:
|
||||||
"""
|
"""
|
||||||
生成基础嵌入向量(陈述句、分块、对话)
|
生成基础嵌入向量(陈述句、分块、对话)
|
||||||
@@ -836,7 +870,7 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def _generate_entity_embeddings(
|
async def _generate_entity_embeddings(
|
||||||
self, triplet_maps: List[Dict[str, Any]]
|
self, triplet_maps: List[Dict[str, Any]]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
生成实体嵌入向量
|
生成实体嵌入向量
|
||||||
@@ -874,17 +908,15 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
logger.error(f"实体嵌入生成失败: {e}", exc_info=True)
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _assign_extracted_data(
|
async def _assign_extracted_data(
|
||||||
self,
|
self,
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
temporal_maps: List[Dict[str, Any]],
|
temporal_maps: List[Dict[str, Any]],
|
||||||
triplet_maps: List[Dict[str, Any]],
|
triplet_maps: List[Dict[str, Any]],
|
||||||
emotion_maps: List[Dict[str, Any]],
|
emotion_maps: List[Dict[str, Any]],
|
||||||
statement_embedding_maps: List[Dict[str, List[float]]],
|
statement_embedding_maps: List[Dict[str, List[float]]],
|
||||||
chunk_embedding_maps: List[Dict[str, List[float]]],
|
chunk_embedding_maps: List[Dict[str, List[float]]],
|
||||||
dialog_embeddings: List[List[float]],
|
dialog_embeddings: List[List[float]],
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""
|
"""
|
||||||
将提取的数据赋值到语句
|
将提取的数据赋值到语句
|
||||||
@@ -906,12 +938,12 @@ class ExtractionOrchestrator:
|
|||||||
# 确保列表长度匹配
|
# 确保列表长度匹配
|
||||||
expected_length = len(dialog_data_list)
|
expected_length = len(dialog_data_list)
|
||||||
if (
|
if (
|
||||||
len(temporal_maps) != expected_length
|
len(temporal_maps) != expected_length
|
||||||
or len(triplet_maps) != expected_length
|
or len(triplet_maps) != expected_length
|
||||||
or len(emotion_maps) != expected_length
|
or len(emotion_maps) != expected_length
|
||||||
or len(statement_embedding_maps) != expected_length
|
or len(statement_embedding_maps) != expected_length
|
||||||
or len(chunk_embedding_maps) != expected_length
|
or len(chunk_embedding_maps) != expected_length
|
||||||
or len(dialog_embeddings) != expected_length
|
or len(dialog_embeddings) != expected_length
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
f"数据大小不匹配 - 对话: {len(dialog_data_list)}, "
|
||||||
@@ -999,15 +1031,17 @@ class ExtractionOrchestrator:
|
|||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _create_nodes_and_edges(
|
async def _create_nodes_and_edges(
|
||||||
self, dialog_data_list: List[DialogData]
|
self, dialog_data_list: List[DialogData]
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
List[DialogueNode],
|
List[DialogueNode],
|
||||||
List[ChunkNode],
|
List[ChunkNode],
|
||||||
List[StatementNode],
|
List[StatementNode],
|
||||||
List[ExtractedEntityNode],
|
List[ExtractedEntityNode],
|
||||||
|
List[PerceptualNode],
|
||||||
List[StatementChunkEdge],
|
List[StatementChunkEdge],
|
||||||
List[StatementEntityEdge],
|
List[StatementEntityEdge],
|
||||||
List[EntityEntityEdge],
|
List[EntityEntityEdge],
|
||||||
|
List[PerceptualEdge]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
创建图数据库节点和边
|
创建图数据库节点和边
|
||||||
@@ -1031,6 +1065,8 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edges = []
|
statement_chunk_edges = []
|
||||||
statement_entity_edges = []
|
statement_entity_edges = []
|
||||||
entity_entity_edges = []
|
entity_entity_edges = []
|
||||||
|
perceptual_nodes = []
|
||||||
|
perceptual_edges = []
|
||||||
|
|
||||||
# 用于去重的集合
|
# 用于去重的集合
|
||||||
entity_id_set = set()
|
entity_id_set = set()
|
||||||
@@ -1075,6 +1111,45 @@ class ExtractionOrchestrator:
|
|||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
|
|
||||||
|
for p, file_type in chunk.files:
|
||||||
|
|
||||||
|
meta = p.meta_data or {}
|
||||||
|
content_meta = meta.get("content", {})
|
||||||
|
|
||||||
|
# 生成 summary embedding(如果有 embedder_client)
|
||||||
|
summary_embedding = None
|
||||||
|
if self.embedder_client and p.summary:
|
||||||
|
try:
|
||||||
|
summary_embedding = (await self.embedder_client.response([p.summary]))[0]
|
||||||
|
except Exception as emb_err:
|
||||||
|
print(f"Failed to embed perceptual summary: {emb_err}")
|
||||||
|
|
||||||
|
perceptual = PerceptualNode(
|
||||||
|
name=f"Perceptual_{p.id}",
|
||||||
|
**{
|
||||||
|
"id": str(p.id),
|
||||||
|
"end_user_id": str(p.end_user_id),
|
||||||
|
"perceptual_type": p.perceptual_type,
|
||||||
|
"file_path": p.file_path or "",
|
||||||
|
"file_name": p.file_name or "",
|
||||||
|
"file_ext": p.file_ext or "",
|
||||||
|
"summary": p.summary or "",
|
||||||
|
"keywords": content_meta.get("keywords", []),
|
||||||
|
"topic": content_meta.get("topic", ""),
|
||||||
|
"domain": content_meta.get("domain", ""),
|
||||||
|
"created_at": p.created_time.isoformat() if p.created_time else None,
|
||||||
|
"file_type": file_type,
|
||||||
|
"summary_embedding": summary_embedding,
|
||||||
|
})
|
||||||
|
perceptual_nodes.append(perceptual)
|
||||||
|
perceptual_edges.append(PerceptualEdge(
|
||||||
|
source=perceptual.id,
|
||||||
|
target=chunk.id,
|
||||||
|
end_user_id=dialog_data.end_user_id,
|
||||||
|
run_id=dialog_data.run_id,
|
||||||
|
created_at=dialog_data.created_at,
|
||||||
|
))
|
||||||
|
|
||||||
# 处理每个陈述句
|
# 处理每个陈述句
|
||||||
for statement in chunk.statements:
|
for statement in chunk.statements:
|
||||||
# 创建陈述句节点
|
# 创建陈述句节点
|
||||||
@@ -1083,15 +1158,19 @@ class ExtractionOrchestrator:
|
|||||||
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
name=f"Statement_{statement.id}", # 添加必需的 name 字段
|
||||||
chunk_id=chunk.id,
|
chunk_id=chunk.id,
|
||||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL),
|
||||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
# 添加必需的 temporal_info 字段
|
||||||
|
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
end_user_id=dialog_data.end_user_id,
|
end_user_id=dialog_data.end_user_id,
|
||||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||||
statement=statement.statement,
|
statement=statement.statement,
|
||||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||||
statement_embedding=statement.statement_embedding,
|
statement_embedding=statement.statement_embedding,
|
||||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
valid_at=statement.temporal_validity.valid_at if hasattr(statement,
|
||||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
|
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement,
|
||||||
|
'temporal_validity') and statement.temporal_validity else None,
|
||||||
created_at=dialog_data.created_at,
|
created_at=dialog_data.created_at,
|
||||||
expired_at=dialog_data.expired_at,
|
expired_at=dialog_data.expired_at,
|
||||||
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
config_id=dialog_data.config_id if hasattr(dialog_data, 'config_id') else None,
|
||||||
@@ -1141,7 +1220,8 @@ class ExtractionOrchestrator:
|
|||||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||||
|
# 添加必需的 connect_strength 字段
|
||||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||||
name_embedding=getattr(entity, 'name_embedding', None),
|
name_embedding=getattr(entity, 'name_embedding', None),
|
||||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||||
@@ -1248,25 +1328,197 @@ class ExtractionOrchestrator:
|
|||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
entity_nodes,
|
entity_nodes,
|
||||||
|
perceptual_nodes,
|
||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
|
perceptual_edges
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _update_end_user_other_name(
|
||||||
|
self,
|
||||||
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
|
dialog_data_list: List[DialogData]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
从 Neo4j 读取用户实体的最终 aliases,同步到 end_user 和 end_user_info 表
|
||||||
|
|
||||||
|
注意:
|
||||||
|
1. other_name 使用本次对话提取的第一个别名(保持时间顺序)
|
||||||
|
2. aliases 从 Neo4j 读取(保持完整性)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_nodes: 实体节点列表
|
||||||
|
dialog_data_list: 对话数据列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not dialog_data_list:
|
||||||
|
logger.warning("dialog_data_list 为空,跳过用户别名同步")
|
||||||
|
return
|
||||||
|
|
||||||
|
end_user_id = dialog_data_list[0].end_user_id
|
||||||
|
if not end_user_id:
|
||||||
|
logger.warning("end_user_id 为空,跳过用户别名同步")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1. 提取本次对话的用户别名(保持 LLM 提取的原始顺序,不排序)
|
||||||
|
current_aliases = self._extract_current_aliases(entity_nodes)
|
||||||
|
|
||||||
|
# 2. 从 Neo4j 获取完整 aliases(权威数据源)
|
||||||
|
neo4j_aliases = await self._fetch_neo4j_user_aliases(end_user_id)
|
||||||
|
|
||||||
|
if not neo4j_aliases:
|
||||||
|
# Neo4j 中没有别名,使用本次对话提取的别名
|
||||||
|
neo4j_aliases = current_aliases
|
||||||
|
if not neo4j_aliases:
|
||||||
|
logger.debug(f"aliases 为空,跳过同步: end_user_id={end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"本次对话提取的 aliases: {current_aliases}")
|
||||||
|
logger.info(f"Neo4j 中的完整 aliases: {neo4j_aliases}")
|
||||||
|
|
||||||
|
# 3. 同步到数据库
|
||||||
|
end_user_uuid = uuid.UUID(end_user_id)
|
||||||
|
with get_db_context() as db:
|
||||||
|
# 更新 end_user 表
|
||||||
|
end_user = EndUserRepository(db).get_by_id(end_user_uuid)
|
||||||
|
if not end_user:
|
||||||
|
logger.warning(f"未找到 end_user_id={end_user_id} 的用户记录")
|
||||||
|
return
|
||||||
|
|
||||||
|
new_name = self._resolve_other_name(end_user.other_name, current_aliases, neo4j_aliases)
|
||||||
|
if new_name is not None:
|
||||||
|
end_user.other_name = new_name
|
||||||
|
logger.info(f"更新 end_user 表 other_name → {new_name}")
|
||||||
|
else:
|
||||||
|
logger.debug(f"end_user 表 other_name 保持不变: {end_user.other_name}")
|
||||||
|
|
||||||
|
# 更新或创建 end_user_info 记录
|
||||||
|
info = EndUserInfoRepository(db).get_by_end_user_id(end_user_uuid)
|
||||||
|
if info:
|
||||||
|
new_name_info = self._resolve_other_name(info.other_name, current_aliases, neo4j_aliases)
|
||||||
|
if new_name_info is not None:
|
||||||
|
info.other_name = new_name_info
|
||||||
|
logger.info(f"更新 end_user_info 表 other_name → {new_name_info}")
|
||||||
|
if info.aliases != neo4j_aliases:
|
||||||
|
info.aliases = neo4j_aliases
|
||||||
|
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||||
|
else:
|
||||||
|
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||||
|
# 确保 first_alias 不是占位名称
|
||||||
|
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
db.add(EndUserInfo(
|
||||||
|
end_user_id=end_user_uuid,
|
||||||
|
other_name=first_alias,
|
||||||
|
aliases=neo4j_aliases,
|
||||||
|
meta_data={}
|
||||||
|
))
|
||||||
|
logger.info(f"创建 end_user_info 记录,other_name={first_alias}, aliases={neo4j_aliases}")
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新 end_user other_name 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||||
|
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||||
|
|
||||||
|
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||||
|
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||||
|
|
||||||
|
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||||
|
第一个别名将被用作 other_name。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity_nodes: 实体节点列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||||
|
"""
|
||||||
|
for entity in entity_nodes:
|
||||||
|
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
aliases = getattr(entity, 'aliases', []) or []
|
||||||
|
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||||
|
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||||
|
return filtered
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||||
|
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||||
|
cypher = """
|
||||||
|
MATCH (e:ExtractedEntity)
|
||||||
|
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||||
|
RETURN e.aliases AS aliases
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
result = await Neo4jConnector().execute_query(cypher, end_user_id=end_user_id)
|
||||||
|
if not result:
|
||||||
|
logger.debug(f"Neo4j 中未找到用户实体: end_user_id={end_user_id}")
|
||||||
|
return []
|
||||||
|
aliases = result[0].get('aliases') or []
|
||||||
|
if not aliases:
|
||||||
|
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||||
|
return []
|
||||||
|
# 过滤掉占位名称,防止历史脏数据传播
|
||||||
|
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
def _resolve_other_name(
|
||||||
|
self,
|
||||||
|
current: Optional[str],
|
||||||
|
current_aliases: List[str],
|
||||||
|
neo4j_aliases: List[str]
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
||||||
|
|
||||||
|
决策规则:
|
||||||
|
- 为空或为占位名称 → 用本次对话第一个别名
|
||||||
|
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
||||||
|
- 否则 → 保持不变(返回 None)
|
||||||
|
|
||||||
|
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||||
|
"""
|
||||||
|
# 当前值为空或为占位名称时,需要更新
|
||||||
|
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
candidate = current_aliases[0].strip() if current_aliases else None
|
||||||
|
# 确保候选值不是占位名称
|
||||||
|
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
return None
|
||||||
|
return candidate
|
||||||
|
if current not in neo4j_aliases:
|
||||||
|
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||||
|
# 确保候选值不是占位名称
|
||||||
|
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
return None
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def _run_dedup_and_write_summary(
|
async def _run_dedup_and_write_summary(
|
||||||
self,
|
self,
|
||||||
dialogue_nodes: List[DialogueNode],
|
dialogue_nodes: List[DialogueNode],
|
||||||
chunk_nodes: List[ChunkNode],
|
chunk_nodes: List[ChunkNode],
|
||||||
statement_nodes: List[StatementNode],
|
statement_nodes: List[StatementNode],
|
||||||
entity_nodes: List[ExtractedEntityNode],
|
entity_nodes: List[ExtractedEntityNode],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
dialog_data_list: List[DialogData],
|
dialog_data_list: List[DialogData],
|
||||||
) -> Tuple[
|
) -> tuple[
|
||||||
Tuple[List[DialogueNode], List[ChunkNode], List[StatementNode]],
|
list[DialogueNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[ChunkNode],
|
||||||
Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]],
|
list[StatementNode],
|
||||||
|
list[ExtractedEntityNode],
|
||||||
|
list[StatementChunkEdge],
|
||||||
|
list[StatementEntityEdge],
|
||||||
|
list[EntityEntityEdge],
|
||||||
|
list[DialogData],
|
||||||
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
执行两阶段去重并写入汇总
|
执行两阶段去重并写入汇总
|
||||||
@@ -1329,6 +1581,8 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
dedup_statement_entity_edges,
|
dedup_statement_entity_edges,
|
||||||
dedup_entity_entity_edges,
|
dedup_entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_entity_nodes = dedup_entity_nodes
|
final_entity_nodes = dedup_entity_nodes
|
||||||
@@ -1336,7 +1590,16 @@ class ExtractionOrchestrator:
|
|||||||
final_entity_entity_edges = dedup_entity_entity_edges
|
final_entity_entity_edges = dedup_entity_entity_edges
|
||||||
else:
|
else:
|
||||||
# 正式模式:执行完整的两阶段去重
|
# 正式模式:执行完整的两阶段去重
|
||||||
result_tuple = await dedup_layers_and_merge_and_return(
|
(
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
final_entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
final_statement_entity_edges,
|
||||||
|
final_entity_entity_edges,
|
||||||
|
dedup_details,
|
||||||
|
) = await dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
@@ -1350,21 +1613,21 @@ class ExtractionOrchestrator:
|
|||||||
llm_client=self.llm_client,
|
llm_client=self.llm_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 解包返回值
|
|
||||||
(
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
final_entity_nodes,
|
|
||||||
_,
|
|
||||||
final_statement_entity_edges,
|
|
||||||
final_entity_entity_edges,
|
|
||||||
dedup_details,
|
|
||||||
) = result_tuple
|
|
||||||
|
|
||||||
# 保存去重消歧的详细记录到实例变量
|
# 保存去重消歧的详细记录到实例变量
|
||||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||||
|
|
||||||
|
result_tuple = (
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
final_entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
final_statement_entity_edges,
|
||||||
|
final_entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||||
@@ -1415,7 +1678,6 @@ class ExtractionOrchestrator:
|
|||||||
len(entity_entity_edges), len(final_entity_entity_edges)
|
len(entity_entity_edges), len(final_entity_entity_edges)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
# 写入提取结果汇总(试运行和正式模式都需要生成)
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
@@ -1436,10 +1698,10 @@ class ExtractionOrchestrator:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _save_dedup_details(
|
def _save_dedup_details(
|
||||||
self,
|
self,
|
||||||
dedup_details: Dict[str, Any],
|
dedup_details: Dict[str, Any],
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
保存去重消歧的详细记录到实例变量(基于内存数据结构)
|
||||||
@@ -1537,15 +1799,16 @@ class ExtractionOrchestrator:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
logger.debug(f"解析消歧记录失败: {record}, 错误: {e}")
|
||||||
|
|
||||||
logger.info(f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
logger.info(
|
||||||
|
f"保存去重消歧记录:{len(self.dedup_merge_records)} 个合并记录,{len(self.dedup_disamb_records)} 个消歧记录")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
logger.error(f"保存去重消歧详情失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _analyze_entity_merges(
|
async def _analyze_entity_merges(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
分析实体合并情况,直接使用内存中的合并记录(不再解析日志文件)
|
||||||
@@ -1585,9 +1848,9 @@ class ExtractionOrchestrator:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def _analyze_entity_disambiguation(
|
async def _analyze_entity_disambiguation(
|
||||||
self,
|
self,
|
||||||
original_entities: List[ExtractedEntityNode],
|
original_entities: List[ExtractedEntityNode],
|
||||||
final_entities: List[ExtractedEntityNode]
|
final_entities: List[ExtractedEntityNode]
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
分析实体消歧情况,直接使用内存中的消歧记录(不再解析日志文件)
|
||||||
@@ -1645,9 +1908,9 @@ class ExtractionOrchestrator:
|
|||||||
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
return type_mapping.get(entity_type, f"{entity_type}实体节点")
|
||||||
|
|
||||||
async def _output_relationship_creation_results(
|
async def _output_relationship_creation_results(
|
||||||
self,
|
self,
|
||||||
entity_entity_edges: List[EntityEntityEdge],
|
entity_entity_edges: List[EntityEntityEdge],
|
||||||
entity_nodes: List[ExtractedEntityNode]
|
entity_nodes: List[ExtractedEntityNode]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
输出关系创建结果
|
输出关系创建结果
|
||||||
@@ -1681,13 +1944,13 @@ class ExtractionOrchestrator:
|
|||||||
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
logger.error(f"输出关系创建结果失败: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _send_dedup_progress_callback(
|
async def _send_dedup_progress_callback(
|
||||||
self,
|
self,
|
||||||
original_entities: int,
|
original_entities: int,
|
||||||
final_entities: int,
|
final_entities: int,
|
||||||
original_stmt_edges: int,
|
original_stmt_edges: int,
|
||||||
final_stmt_edges: int,
|
final_stmt_edges: int,
|
||||||
original_ent_edges: int,
|
original_ent_edges: int,
|
||||||
final_ent_edges: int,
|
final_ent_edges: int,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
发送去重消歧完成的进度回调,传递具体的去重和消歧效果
|
||||||
@@ -1715,7 +1978,8 @@ class ExtractionOrchestrator:
|
|||||||
"original_count": original_entities,
|
"original_count": original_entities,
|
||||||
"final_count": final_entities,
|
"final_count": final_entities,
|
||||||
"reduced_count": entities_reduced,
|
"reduced_count": entities_reduced,
|
||||||
"reduction_rate": round(entities_reduced / original_entities * 100, 1) if original_entities > 0 else 0,
|
"reduction_rate": round(entities_reduced / original_entities * 100,
|
||||||
|
1) if original_entities > 0 else 0,
|
||||||
},
|
},
|
||||||
"statement_entity_edges": {
|
"statement_entity_edges": {
|
||||||
"original_count": original_stmt_edges,
|
"original_count": original_stmt_edges,
|
||||||
@@ -1790,7 +2054,8 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
disamb_examples.append({
|
disamb_examples.append({
|
||||||
"entity1_name": entity_name,
|
"entity1_name": entity_name,
|
||||||
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:", "").strip() if "vs" in disamb_type else "未知",
|
"entity1_type": disamb_type.split("vs")[0].replace("消歧阻断:",
|
||||||
|
"").strip() if "vs" in disamb_type else "未知",
|
||||||
"entity2_name": entity_name,
|
"entity2_name": entity_name,
|
||||||
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
"entity2_type": disamb_type.split("vs")[1].strip() if "vs" in disamb_type else "未知",
|
||||||
"description": f"{entity_name},消歧区分成功"
|
"description": f"{entity_name},消歧区分成功"
|
||||||
@@ -1815,9 +2080,9 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs(
|
async def get_chunked_dialogs(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "group_1",
|
end_user_id: str = "group_1",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从测试数据生成分块对话
|
"""从测试数据生成分块对话
|
||||||
|
|
||||||
@@ -1924,10 +2189,10 @@ async def get_chunked_dialogs(
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_data(
|
def preprocess_data(
|
||||||
input_path: Optional[str] = None,
|
input_path: Optional[str] = None,
|
||||||
output_path: Optional[str] = None,
|
output_path: Optional[str] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
indices: Optional[List[int]] = None
|
indices: Optional[List[int]] = None
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""数据预处理
|
"""数据预处理
|
||||||
|
|
||||||
@@ -1946,7 +2211,8 @@ def preprocess_data(
|
|||||||
)
|
)
|
||||||
preprocessor = DataPreprocessor()
|
preprocessor = DataPreprocessor()
|
||||||
try:
|
try:
|
||||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path,
|
||||||
|
skip_cleaning=skip_cleaning, indices=indices)
|
||||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||||
return cleaned_data
|
return cleaned_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1955,9 +2221,9 @@ def preprocess_data(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_from_preprocessed(
|
async def get_chunked_dialogs_from_preprocessed(
|
||||||
data: List[DialogData],
|
data: List[DialogData],
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""从预处理后的数据中生成分块
|
"""从预处理后的数据中生成分块
|
||||||
|
|
||||||
@@ -1988,15 +2254,15 @@ async def get_chunked_dialogs_from_preprocessed(
|
|||||||
|
|
||||||
|
|
||||||
async def get_chunked_dialogs_with_preprocessing(
|
async def get_chunked_dialogs_with_preprocessing(
|
||||||
chunker_strategy: str = "RecursiveChunker",
|
chunker_strategy: str = "RecursiveChunker",
|
||||||
end_user_id: str = "default",
|
end_user_id: str = "default",
|
||||||
user_id: str = "default",
|
user_id: str = "default",
|
||||||
apply_id: str = "default",
|
apply_id: str = "default",
|
||||||
indices: Optional[List[int]] = None,
|
indices: Optional[List[int]] = None,
|
||||||
input_data_path: Optional[str] = None,
|
input_data_path: Optional[str] = None,
|
||||||
llm_client: Optional[Any] = None,
|
llm_client: Optional[Any] = None,
|
||||||
skip_cleaning: bool = True,
|
skip_cleaning: bool = True,
|
||||||
pruning_config: Optional[Dict] = None,
|
pruning_config: Optional[Dict] = None,
|
||||||
) -> List[DialogData]:
|
) -> List[DialogData]:
|
||||||
"""包含数据预处理步骤的完整分块流程
|
"""包含数据预处理步骤的完整分块流程
|
||||||
|
|
||||||
@@ -2046,7 +2312,8 @@ async def get_chunked_dialogs_with_preprocessing(
|
|||||||
if pruning_config:
|
if pruning_config:
|
||||||
# 使用传入的配置
|
# 使用传入的配置
|
||||||
config = PruningConfig(**pruning_config)
|
config = PruningConfig(**pruning_config)
|
||||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
logger.debug(
|
||||||
|
f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||||
else:
|
else:
|
||||||
# 使用默认配置(关闭剪枝)
|
# 使用默认配置(关闭剪枝)
|
||||||
config = None
|
config = None
|
||||||
|
|||||||
@@ -5,8 +5,11 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
from app.core.memory.models.message_models import DialogData
|
from app.core.memory.models.message_models import DialogData
|
||||||
from app.core.models.base import RedBearModelConfig
|
from app.core.models.base import RedBearModelConfig
|
||||||
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
|
|||||||
return await self.embedder_client.response(texts)
|
return await self.embedder_client.response(texts)
|
||||||
|
|
||||||
# 分批并行处理
|
# 分批并行处理
|
||||||
print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
|
||||||
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
|
||||||
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
|
||||||
|
|
||||||
# 并行发送所有批次
|
# 并行发送所有批次
|
||||||
batch_results = await asyncio.gather(*[
|
batch_results = await asyncio.gather(*[
|
||||||
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
|
|||||||
for batch_result in batch_results:
|
for batch_result in batch_results:
|
||||||
embeddings.extend(batch_result)
|
embeddings.extend(batch_result)
|
||||||
|
|
||||||
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
async def generate_statement_embeddings(
|
async def generate_statement_embeddings(
|
||||||
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
每个对话的陈述句嵌入向量映射列表
|
每个对话的陈述句嵌入向量映射列表
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成陈述句嵌入向量 ===")
|
logger.debug("=== 生成陈述句嵌入向量 ===")
|
||||||
|
|
||||||
# 收集所有陈述句
|
# 收集所有陈述句
|
||||||
all_statements = []
|
all_statements = []
|
||||||
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
|
|||||||
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
|
||||||
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
stmt_embedding_maps[d_idx][stmt_id] = embedding
|
||||||
|
|
||||||
print(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
logger.info(f"为 {len(all_statements)} 个陈述句生成了嵌入向量")
|
||||||
return stmt_embedding_maps
|
return stmt_embedding_maps
|
||||||
|
|
||||||
async def generate_chunk_embeddings(
|
async def generate_chunk_embeddings(
|
||||||
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
每个对话的分块嵌入向量映射列表
|
每个对话的分块嵌入向量映射列表
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成分块嵌入向量 ===")
|
logger.debug("=== 生成分块嵌入向量 ===")
|
||||||
|
|
||||||
# 收集所有分块
|
# 收集所有分块
|
||||||
all_chunks = []
|
all_chunks = []
|
||||||
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
|
|||||||
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
|
||||||
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
chunk_embedding_maps[d_idx][chunk_id] = embedding
|
||||||
|
|
||||||
print(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
logger.info(f"为 {len(all_chunks)} 个分块生成了嵌入向量")
|
||||||
return chunk_embedding_maps
|
return chunk_embedding_maps
|
||||||
|
|
||||||
async def generate_dialog_embeddings(
|
async def generate_dialog_embeddings(
|
||||||
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成所有嵌入向量 ===")
|
logger.debug("=== 生成所有嵌入向量 ===")
|
||||||
|
|
||||||
# 并发生成陈述句和分块嵌入向量
|
# 并发生成陈述句和分块嵌入向量
|
||||||
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
|
||||||
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
|
|||||||
# 对话嵌入向量(当前跳过)
|
# 对话嵌入向量(当前跳过)
|
||||||
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
|
||||||
|
|
||||||
print(
|
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量")
|
||||||
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
|
|
||||||
)
|
|
||||||
|
|
||||||
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
|
||||||
|
|
||||||
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
|
|||||||
Returns:
|
Returns:
|
||||||
更新后的三元组映射列表(实体包含嵌入向量)
|
更新后的三元组映射列表(实体包含嵌入向量)
|
||||||
"""
|
"""
|
||||||
print("\n=== 生成实体嵌入向量 ===")
|
logger.debug("=== 生成实体嵌入向量 ===")
|
||||||
|
|
||||||
entity_texts: List[str] = []
|
entity_texts: List[str] = []
|
||||||
entity_refs: List[Any] = []
|
entity_refs: List[Any] = []
|
||||||
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
|
|||||||
entity_refs.append(ent)
|
entity_refs.append(ent)
|
||||||
|
|
||||||
if not entity_texts:
|
if not entity_texts:
|
||||||
print("没有找到需要生成嵌入向量的实体")
|
logger.debug("没有找到需要生成嵌入向量的实体")
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
# 批量生成嵌入向量
|
# 批量生成嵌入向量
|
||||||
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
|
|||||||
|
|
||||||
# 打印前几个嵌入向量的维度
|
# 打印前几个嵌入向量的维度
|
||||||
for i in range(min(5, len(embeddings))):
|
for i in range(min(5, len(embeddings))):
|
||||||
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
|
||||||
|
|
||||||
# 将嵌入向量赋值给实体
|
# 将嵌入向量赋值给实体
|
||||||
for ent, emb in zip(entity_refs, embeddings):
|
for ent, emb in zip(entity_refs, embeddings):
|
||||||
setattr(ent, "name_embedding", emb)
|
setattr(ent, "name_embedding", emb)
|
||||||
|
|
||||||
print(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
logger.info(f"为 {len(entity_refs)} 个实体生成了嵌入向量")
|
||||||
return triplet_maps
|
return triplet_maps
|
||||||
|
|
||||||
|
|
||||||
@@ -296,7 +297,7 @@ async def embedding_generation_all(
|
|||||||
Returns:
|
Returns:
|
||||||
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
|
||||||
"""
|
"""
|
||||||
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
|
||||||
|
|
||||||
generator = EmbeddingGenerator(embedding_id)
|
generator = EmbeddingGenerator(embedding_id)
|
||||||
|
|
||||||
|
|||||||
@@ -188,7 +188,6 @@ async def _process_chunk_summary(
|
|||||||
response_model=MemorySummaryResponse,
|
response_model=MemorySummaryResponse,
|
||||||
)
|
)
|
||||||
summary_text = structured.summary.strip()
|
summary_text = structured.summary.strip()
|
||||||
|
|
||||||
# Generate title and type for the summary
|
# Generate title and type for the summary
|
||||||
title = None
|
title = None
|
||||||
episodic_type = None
|
episodic_type = None
|
||||||
|
|||||||
@@ -5,6 +5,15 @@
|
|||||||
===Task===
|
===Task===
|
||||||
Extract entities and knowledge triplets from the given statement.
|
Extract entities and knowledge triplets from the given statement.
|
||||||
|
|
||||||
|
**⚠️ CRITICAL REQUIREMENTS:**
|
||||||
|
1. **ALIASES ORDER IS CRITICAL**: The FIRST alias in the array will be used as the user's primary display name (other_name). You MUST put the most important/frequently used name FIRST.
|
||||||
|
2. **ALWAYS include aliases field**: Even if empty, you MUST include "aliases": [] in EVERY entity.
|
||||||
|
|
||||||
|
<!-- TODO: v0.2.10 - denied_aliases 功能暂时禁用,将通过 Cypher 查询实现
|
||||||
|
2. **DENIED_ALIASES**: When user explicitly denies a name (e.g., "我不叫X", "I'm not called X"), you MUST put X in denied_aliases field, NOT in aliases.
|
||||||
|
3. **ALWAYS include both fields**: Even if empty, you MUST include "aliases": [] and "denied_aliases": [] in EVERY entity.
|
||||||
|
-->
|
||||||
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
**重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。**
|
**重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。**
|
||||||
{% else %}
|
{% else %}
|
||||||
@@ -18,34 +27,29 @@ Extract entities and knowledge triplets from the given statement.
|
|||||||
{% if ontology_types %}
|
{% if ontology_types %}
|
||||||
===Ontology Type Guidance===
|
===Ontology Type Guidance===
|
||||||
|
|
||||||
**CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.**
|
**CRITICAL: Use ONLY predefined type names below. If no exact match, use CLOSEST type. NEVER invent new types.**
|
||||||
|
|
||||||
**If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.**
|
**Type Priority:**
|
||||||
|
1. [场景类型] Scene Types (domain-specific, prefer first)
|
||||||
|
2. [通用类型] General Types (standard ontologies)
|
||||||
|
3. [通用父类] Parent Types (hierarchy context)
|
||||||
|
|
||||||
**Type Priority (from highest to lowest):**
|
**Rules:**
|
||||||
1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first
|
- Type MUST exactly match predefined names
|
||||||
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
|
- Do NOT modify, translate, or abbreviate type names
|
||||||
3. **[通用父类] Parent Types** - Provide type hierarchy context
|
- Prefer scene types over general types
|
||||||
|
|
||||||
**Type Matching Rules:**
|
**Predefined Types:**
|
||||||
- Entity type MUST exactly match one of the predefined type names below
|
|
||||||
- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list
|
|
||||||
- Do NOT modify, translate, abbreviate, or create variations of type names
|
|
||||||
- Prefer scene types (marked [场景类型]) over general types when both could apply
|
|
||||||
- If uncertain, check the type description to find the best match
|
|
||||||
|
|
||||||
**Predefined Ontology Types:**
|
|
||||||
{{ ontology_types }}
|
{{ ontology_types }}
|
||||||
|
|
||||||
{% if type_hierarchy_hints %}
|
{% if type_hierarchy_hints %}
|
||||||
**Type Hierarchy Reference:**
|
**Hierarchy:**
|
||||||
The following shows type inheritance relationships (Child → Parent → Grandparent):
|
|
||||||
{% for hint in type_hierarchy_hints %}
|
{% for hint in type_hierarchy_hints %}
|
||||||
- {{ hint }}
|
- {{ hint }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
**ALLOWED Type Names (use EXACTLY one of these, no exceptions):**
|
**ALLOWED Names:**
|
||||||
{{ ontology_type_names | join(', ') }}
|
{{ ontology_type_names | join(', ') }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@@ -62,66 +66,94 @@ The following shows type inheritance relationships (Child → Parent → Grandpa
|
|||||||
- **Entity descriptions must be in English**
|
- **Entity descriptions must be in English**
|
||||||
- **Examples must be in English**
|
- **Examples must be in English**
|
||||||
{% endif %}
|
{% endif %}
|
||||||
- **Semantic Memory Classification (is_explicit_memory):**
|
- **Semantic Memory (is_explicit_memory):**
|
||||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
* `true` for: Concepts, Knowledge, Definitions, Theories, Methods (e.g., "Machine Learning", "REST API")
|
||||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy"
|
* `false` for: People, Organizations, Locations, Events, Specific objects
|
||||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity"
|
* For `is_explicit_memory=true`, provide concise example (~20 chars{% if language == "zh" %},使用中文{% endif %})
|
||||||
- **Definitions:** "API (Application Programming Interface)", "REST API"
|
|
||||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics"
|
**🚨🚨🚨 ALIASES & DENIED_ALIASES - MANDATORY FIELDS 🚨🚨🚨**
|
||||||
- **Theories:** "Evolution Theory", "Quantum Mechanics"
|
|
||||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm"
|
**CRITICAL RULES (违反将导致提取失败):**
|
||||||
- **Technical Terms:** "Neural Network", "Database"
|
|
||||||
* Set to `false` for:
|
1. **EVERY entity MUST have aliases field:**
|
||||||
- **People:** "John Smith", "Dr. Wang"
|
- `"aliases": [...]` - REQUIRED, even if empty `[]`
|
||||||
- **Organizations:** "Microsoft", "Harvard University"
|
|
||||||
- **Locations:** "Beijing", "Central Park"
|
2. **ALIASES - 别名提取规则:**
|
||||||
- **Events:** "2024 Conference", "Project Meeting"
|
|
||||||
- **Specific objects:** "iPhone 15", "Building A"
|
|
||||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
|
||||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
|
||||||
* The example should be:
|
|
||||||
- **Specific and concrete**: Use real-world scenarios or applications
|
|
||||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
- **使用中文**
|
- 包含:昵称、全名、简称、别称、网名等
|
||||||
|
- 顺序:**第一个别名将作为用户的主显示名称(other_name),必须把最重要/最常用的名字放在第一位**
|
||||||
|
- 提取顺序:严格按照对话中首次出现的顺序
|
||||||
|
- 示例:
|
||||||
|
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name)
|
||||||
|
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name)
|
||||||
|
- 空值:如果没有别名,使用 `[]`
|
||||||
|
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
|
||||||
{% else %}
|
{% else %}
|
||||||
- **In English**
|
- Include: nicknames, full names, abbreviations, alternative names
|
||||||
|
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
|
||||||
|
- Extraction order: Strictly follow the order of first appearance in conversation
|
||||||
|
- Examples:
|
||||||
|
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
|
||||||
|
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||||
|
- Empty: If no aliases, use `[]`
|
||||||
|
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
|
||||||
{% endif %}
|
{% endif %}
|
||||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
|
||||||
- **Aliases Extraction:**
|
|
||||||
|
|
||||||
|
3. **USER ENTITY SPECIAL HANDLING:**
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
* 别名使用中文
|
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||||
|
- 用户的真实姓名:放入 aliases
|
||||||
|
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||||
|
- 示例:
|
||||||
|
* "我叫李明" → name="用户", aliases=["李明"]
|
||||||
|
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||||
|
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||||
{% else %}
|
{% else %}
|
||||||
* Aliases should be in English
|
- User entity name field: use "User" or "I"
|
||||||
|
- User's real name: put in aliases
|
||||||
|
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||||
|
- Examples:
|
||||||
|
* "I'm John" → name="User", aliases=["John"]
|
||||||
|
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||||
|
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||||
{% endif %}
|
{% endif %}
|
||||||
* Include common alternative names, abbreviations and full names
|
|
||||||
* If no aliases exist, use empty array: []
|
|
||||||
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
|
|
||||||
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
4. **ALIASES ORDER:**
|
||||||
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
|
{% if language == "zh" %}
|
||||||
|
- 顺序优先级:按出现顺序,先出现的在前
|
||||||
|
{% else %}
|
||||||
|
- Order priority: by appearance order, first mentioned comes first
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
**EXAMPLES OF CORRECT EXTRACTION:**
|
||||||
|
{% if language == "zh" %}
|
||||||
|
- "我叫张三" → aliases=["张三"] (张三将成为 other_name)
|
||||||
|
- "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name)
|
||||||
|
- "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name)
|
||||||
|
{% else %}
|
||||||
|
- "I'm John" → aliases=["John"] (John will become other_name)
|
||||||
|
- "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
|
||||||
|
- "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name)
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
- Exclude lengthy quotes, dates, temporal expressions
|
||||||
|
- Numeric values: extract as entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||||
|
|
||||||
**Triplet Extraction:**
|
**Triplet Extraction:**
|
||||||
- Extract (subject, predicate, object) triplets where:
|
- Extract (subject, predicate, object) where subject/object are entities, predicate is relationship
|
||||||
- Subject: main entity performing the action or being described
|
|
||||||
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
|
|
||||||
- Object: entity, value, or concept affected by the predicate
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
- subject_name 和 object_name 必须使用中文
|
- subject_name 和 object_name 使用中文
|
||||||
{% else %}
|
{% else %}
|
||||||
- subject_name and object_name must be in English (translate if original is in another language)
|
- subject_name and object_name in English
|
||||||
{% endif %}
|
{% endif %}
|
||||||
- Exclude all temporal expressions from every field
|
- Use ONLY predicates from "Predicate Instructions" (uppercase tokens)
|
||||||
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
|
- Exclude temporal expressions, do NOT include `statement_id`
|
||||||
- Do NOT translate predicate tokens
|
- **When NOT to extract:** emotions, fillers, no clear predicate, standalone nouns
|
||||||
- Do NOT include `statement_id` field (assigned automatically)
|
- **If no valid triplet:** Return triplets: []
|
||||||
|
|
||||||
**When NOT to extract triplets:**
|
|
||||||
- Non-propositional utterances (emotions, fillers, onomatopoeia)
|
|
||||||
- No clear predicate from the given definitions applies
|
|
||||||
- Standalone noun phrases or checklist items → extract as entities only
|
|
||||||
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
|
|
||||||
|
|
||||||
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
|
|
||||||
{%- if predicate_instructions -%}
|
{%- if predicate_instructions -%}
|
||||||
|
|
||||||
**Predicate Instructions:**
|
**Predicate Instructions:**
|
||||||
@@ -207,26 +239,44 @@ Output:
|
|||||||
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
|
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
**Example 4 (别名 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐"
|
||||||
|
Output:
|
||||||
|
{
|
||||||
|
"triplets": [],
|
||||||
|
"entities": [
|
||||||
|
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
**Example 5 (别名顺序 - Chinese):** "我叫陈思远。对了,我的网名叫「远山」"
|
||||||
|
Output:
|
||||||
|
{
|
||||||
|
"triplets": [],
|
||||||
|
"entities": [
|
||||||
|
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远", "远山"], "is_explicit_memory": false}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
===End of Examples===
|
===End of Examples===
|
||||||
|
|
||||||
{% if ontology_types %}
|
{% if ontology_types %}
|
||||||
**⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.**
|
**⚠️ REMINDER: Examples use generic types for illustration. You MUST use predefined types from "ALLOWED Names" above.**
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
===Output Format===
|
===Output Format===
|
||||||
|
|
||||||
**JSON Requirements:**
|
**JSON Requirements:**
|
||||||
- Use only ASCII double quotes (") for JSON structure
|
- Use ASCII double quotes ("), escape with \"
|
||||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
- No Chinese quotes (""), no line breaks in strings
|
||||||
- Escape quotation marks in text with backslashes (\")
|
|
||||||
- Ensure proper string closure and comma separation
|
|
||||||
- No line breaks within JSON string values
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
- **语言要求:实体名称(name)、描述(description)、示例(example)、subject_name、object_name 必须使用中文**
|
- **语言:name、description、example、subject_name、object_name 使用中文**
|
||||||
{% else %}
|
{% else %}
|
||||||
- **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English**
|
- **Language: names, descriptions, examples in English (translate if needed)**
|
||||||
- **If the original text is in Chinese, translate all names to English**
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
|
||||||
|
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
|
||||||
|
|
||||||
{{ json_schema }}
|
{{ json_schema }}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
|
|||||||
from .llm import RedBearLLM
|
from .llm import RedBearLLM
|
||||||
from .embedding import RedBearEmbeddings
|
from .embedding import RedBearEmbeddings
|
||||||
from .rerank import RedBearRerank
|
from .rerank import RedBearRerank
|
||||||
|
from .generation import RedBearImageGenerator, RedBearVideoGenerator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RedBearModelConfig",
|
"RedBearModelConfig",
|
||||||
@@ -9,5 +10,7 @@ __all__ = [
|
|||||||
"RedBearEmbeddings",
|
"RedBearEmbeddings",
|
||||||
"RedBearRerank",
|
"RedBearRerank",
|
||||||
"RedBearModelFactory",
|
"RedBearModelFactory",
|
||||||
"get_provider_llm_class"
|
"get_provider_llm_class",
|
||||||
|
"RedBearImageGenerator",
|
||||||
|
"RedBearVideoGenerator"
|
||||||
]
|
]
|
||||||
@@ -67,7 +67,7 @@ class RedBearModelFactory:
|
|||||||
**config.extra_params
|
**config.extra_params
|
||||||
}
|
}
|
||||||
|
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]:
|
||||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||||
# 这样可以分别控制连接超时和读取超时
|
# 这样可以分别控制连接超时和读取超时
|
||||||
import httpx
|
import httpx
|
||||||
@@ -160,11 +160,13 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
|||||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]:
|
||||||
if type == ModelType.LLM:
|
if type == ModelType.LLM:
|
||||||
return OpenAI
|
return OpenAI
|
||||||
elif type == ModelType.CHAT:
|
elif type == ModelType.CHAT:
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
|
else:
|
||||||
|
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
return ChatTongyi
|
return ChatTongyi
|
||||||
elif provider == ModelProvider.OLLAMA:
|
elif provider == ModelProvider.OLLAMA:
|
||||||
|
|||||||
@@ -1,23 +1,190 @@
|
|||||||
|
|
||||||
from typing import Any, Dict, List, Optional, TypeVar, Callable
|
from typing import Any, Dict, List, Optional, Union
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
|
||||||
from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
|
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
|
|
||||||
class RedBearEmbeddings(Embeddings):
|
class RedBearEmbeddings(Embeddings):
|
||||||
"""Embedding → 完全符合 LangChain Embeddings"""
|
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)"""
|
||||||
|
|
||||||
def __init__(self, config: RedBearModelConfig):
|
def __init__(self, config: RedBearModelConfig):
|
||||||
self._model = self._create_model(config)
|
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
|
||||||
|
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎使用 Ark SDK
|
||||||
|
self._client = self._create_volcano_client(config)
|
||||||
|
self._model = None
|
||||||
|
else:
|
||||||
|
# 其他 provider 使用 LangChain
|
||||||
|
self._model = self._create_model(config)
|
||||||
|
self._client = None
|
||||||
|
|
||||||
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
def _create_model(self, config: RedBearModelConfig) -> Embeddings:
|
||||||
"""根据配置创建模型"""
|
"""根据配置创建 LangChain 模型"""
|
||||||
embedding_class = get_provider_embedding_class(config.provider)
|
embedding_class = get_provider_embedding_class(config.provider)
|
||||||
model_params = RedBearModelFactory.get_model_params(config)
|
model_params = RedBearModelFactory.get_model_params(config)
|
||||||
return embedding_class(**model_params)
|
return embedding_class(**model_params)
|
||||||
|
|
||||||
|
def _create_volcano_client(self, config: RedBearModelConfig):
|
||||||
|
"""创建火山引擎客户端"""
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
|
||||||
|
# ==================== LangChain 标准接口 ====================
|
||||||
|
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
return self._model.embed_documents(texts)
|
"""批量文本向量化(LangChain 标准接口)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
contents = [{"type": "text", "text": text} for text in texts]
|
||||||
|
response = self._client.multimodal_embeddings.create(
|
||||||
|
model=self._config.model_name,
|
||||||
|
input=contents,
|
||||||
|
encoding_format="float"
|
||||||
|
)
|
||||||
|
return [response.data.embedding]
|
||||||
|
else:
|
||||||
|
# 其他 provider
|
||||||
|
return self._model.embed_documents(texts)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
return self._model.embed_query(text)
|
"""单个文本向量化(LangChain 标准接口)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
result = self.embed_documents([text])
|
||||||
|
return result[0] if result else []
|
||||||
|
else:
|
||||||
|
# 其他 provider
|
||||||
|
return self._model.embed_query(text)
|
||||||
|
|
||||||
|
# ==================== 多模态扩展方法 ====================
|
||||||
|
|
||||||
|
def embed_multimodal(
|
||||||
|
self,
|
||||||
|
contents: List[Dict[str, Any]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
多模态向量化(仅火山引擎支持)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
contents: 内容列表,格式:
|
||||||
|
- 文本: {"type": "text", "text": "..."}
|
||||||
|
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
|
||||||
|
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
向量列表
|
||||||
|
"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self._client.multimodal_embeddings.create(
|
||||||
|
model=self._config.model_name,
|
||||||
|
input=contents,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return [response.data.embedding]
|
||||||
|
|
||||||
|
async def aembed_multimodal(
|
||||||
|
self,
|
||||||
|
contents: List[Dict[str, Any]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""异步多模态向量化"""
|
||||||
|
# 火山引擎 SDK 暂不支持异步,使用同步方法
|
||||||
|
return self.embed_multimodal(contents, **kwargs)
|
||||||
|
|
||||||
|
def embed_text(self, text: str, **kwargs) -> List[float]:
|
||||||
|
"""文本向量化(便捷方法)"""
|
||||||
|
if self._is_volcano:
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "text", "text": text}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
else:
|
||||||
|
return self.embed_query(text)
|
||||||
|
|
||||||
|
def embed_image(self, image_url: str, **kwargs) -> List[float]:
|
||||||
|
"""图片向量化(仅火山引擎支持)"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "image_url", "image_url": {"url": image_url}}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
|
||||||
|
def embed_video(self, video_url: str, **kwargs) -> List[float]:
|
||||||
|
"""视频向量化(仅火山引擎支持)"""
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.embed_multimodal(
|
||||||
|
[{"type": "video_url", "video_url": {"url": video_url}}],
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
return result[0] if result else []
|
||||||
|
|
||||||
|
def embed_batch(
|
||||||
|
self,
|
||||||
|
items: List[Union[str, Dict[str, Any]]],
|
||||||
|
**kwargs
|
||||||
|
) -> List[List[float]]:
|
||||||
|
"""
|
||||||
|
批量向量化(支持混合类型)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
items: 可以是字符串列表或内容字典列表
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
向量列表
|
||||||
|
"""
|
||||||
|
# 如果全是字符串,使用标准方法
|
||||||
|
if all(isinstance(item, str) for item in items):
|
||||||
|
return self.embed_documents(items)
|
||||||
|
|
||||||
|
# 如果包含字典,需要多模态支持
|
||||||
|
if not self._is_volcano:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 标准化输入格式
|
||||||
|
contents = []
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, str):
|
||||||
|
contents.append({"type": "text", "text": item})
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
contents.append(item)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的输入类型: {type(item)}")
|
||||||
|
|
||||||
|
return self.embed_multimodal(contents, **kwargs)
|
||||||
|
|
||||||
|
# ==================== 工具方法 ====================
|
||||||
|
|
||||||
|
def is_multimodal_supported(self) -> bool:
|
||||||
|
"""检查是否支持多模态"""
|
||||||
|
return self._is_volcano
|
||||||
|
|
||||||
|
def get_provider(self) -> str:
|
||||||
|
"""获取 provider"""
|
||||||
|
return self._config.provider
|
||||||
|
|
||||||
|
|
||||||
|
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
|
||||||
|
RedBearMultimodalEmbeddings = RedBearEmbeddings
|
||||||
|
|||||||
344
api/app/core/models/generation.py
Normal file
344
api/app/core/models/generation.py
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
"""
|
||||||
|
图片和视频生成模型封装
|
||||||
|
|
||||||
|
支持的 Provider:
|
||||||
|
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
|
||||||
|
- OpenAI: 使用 openai SDK
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from volcenginesdkarkruntime import Ark
|
||||||
|
from volcenginesdkarkruntime.types.images.images import (
|
||||||
|
SequentialImageGenerationOptions,
|
||||||
|
ContentGenerationTool,
|
||||||
|
OptimizePromptOptions
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.models.models_model import ModelProvider
|
||||||
|
|
||||||
|
|
||||||
|
class RedBearImageGenerator:
|
||||||
|
"""图片生成模型封装"""
|
||||||
|
|
||||||
|
def __init__(self, config: RedBearModelConfig):
|
||||||
|
self._config = config
|
||||||
|
self._client = self._create_client(config)
|
||||||
|
|
||||||
|
def _create_client(self, config: RedBearModelConfig):
|
||||||
|
"""根据 provider 创建客户端"""
|
||||||
|
provider = config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
# elif provider == ModelProvider.OPENAI:
|
||||||
|
# from openai import OpenAI
|
||||||
|
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的图片生成提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[Any] = None,
|
||||||
|
size: Optional[str] = "2K",
|
||||||
|
output_format: str = "png",
|
||||||
|
response_format: str = "url",
|
||||||
|
watermark: bool = False,
|
||||||
|
sequential_image_generation: Optional[str] = None,
|
||||||
|
sequential_image_generation_options: Optional[Dict] = None,
|
||||||
|
tools: Optional[list] = None,
|
||||||
|
optimize_prompt_options: Optional[Dict] = None,
|
||||||
|
stream: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
image: 参考图片URL或URL列表(图文生图/多图融合)
|
||||||
|
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080" 等(至少3686400像素)
|
||||||
|
output_format: 输出格式,如 "png", "jpg"
|
||||||
|
response_format: 返回格式,"url" 或 "b64_json"
|
||||||
|
watermark: 是否添加水印
|
||||||
|
sequential_image_generation: 组图生成模式,"auto" 或 "disabled"
|
||||||
|
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
|
||||||
|
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
|
||||||
|
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
|
||||||
|
stream: 是否使用流式生成
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
params = {
|
||||||
|
"model": self._config.model_name,
|
||||||
|
"prompt": prompt,
|
||||||
|
"size": size,
|
||||||
|
"output_format": output_format,
|
||||||
|
"response_format": response_format,
|
||||||
|
"watermark": watermark,
|
||||||
|
}
|
||||||
|
|
||||||
|
if image is not None:
|
||||||
|
params["image"] = image
|
||||||
|
|
||||||
|
if sequential_image_generation:
|
||||||
|
params["sequential_image_generation"] = sequential_image_generation
|
||||||
|
if sequential_image_generation_options:
|
||||||
|
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
|
||||||
|
**sequential_image_generation_options
|
||||||
|
)
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
|
||||||
|
|
||||||
|
if optimize_prompt_options:
|
||||||
|
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
params["stream"] = True
|
||||||
|
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.images.generate(**params)
|
||||||
|
|
||||||
|
# elif provider == ModelProvider.OPENAI:
|
||||||
|
# response = self._client.images.generate(
|
||||||
|
# model=self._config.model_name,
|
||||||
|
# prompt=prompt,
|
||||||
|
# size=size,
|
||||||
|
# n=n,
|
||||||
|
# **kwargs
|
||||||
|
# )
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
|
||||||
|
async def agenerate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image: Optional[Any] = None,
|
||||||
|
size: Optional[str] = "2K",
|
||||||
|
output_format: str = "png",
|
||||||
|
response_format: str = "url",
|
||||||
|
watermark: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""异步生成图片"""
|
||||||
|
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class RedBearVideoGenerator:
|
||||||
|
"""视频生成模型封装"""
|
||||||
|
|
||||||
|
def __init__(self, config: RedBearModelConfig):
|
||||||
|
self._config = config
|
||||||
|
self._client = self._create_client(config)
|
||||||
|
|
||||||
|
def _create_client(self, config: RedBearModelConfig):
|
||||||
|
"""根据 provider 创建客户端"""
|
||||||
|
provider = config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
return Ark(api_key=config.api_key, base_url=config.base_url)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的视频生成提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image_url: Optional[str] = None,
|
||||||
|
first_frame_url: Optional[str] = None,
|
||||||
|
last_frame_url: Optional[str] = None,
|
||||||
|
reference_images: Optional[list] = None,
|
||||||
|
draft_task_id: Optional[str] = None,
|
||||||
|
duration: Optional[int] = None,
|
||||||
|
frames: Optional[int] = None,
|
||||||
|
ratio: Optional[str] = None,
|
||||||
|
resolution: Optional[str] = None,
|
||||||
|
generate_audio: bool = False,
|
||||||
|
watermark: bool = False,
|
||||||
|
camera_fixed: bool = False,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
return_last_frame: bool = False,
|
||||||
|
service_tier: str = "default",
|
||||||
|
execution_expires_after: Optional[int] = None,
|
||||||
|
draft: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成视频
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 提示词
|
||||||
|
image_url: 首帧图片URL(图生视频-基于首帧)
|
||||||
|
first_frame_url: 首帧图片URL(图生视频-基于首尾帧)
|
||||||
|
last_frame_url: 尾帧图片URL(图生视频-基于首尾帧)
|
||||||
|
reference_images: 参考图片URL列表(图生视频-基于参考图)
|
||||||
|
draft_task_id: Draft任务ID(基于Draft生成正式视频)
|
||||||
|
duration: 视频时长(秒),与frames二选一
|
||||||
|
frames: 视频帧数,与duration二选一
|
||||||
|
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
|
||||||
|
resolution: 视频分辨率,如 "720p", "1080p"
|
||||||
|
generate_audio: 是否生成音频
|
||||||
|
watermark: 是否添加水印
|
||||||
|
camera_fixed: 是否固定镜头
|
||||||
|
seed: 随机种子
|
||||||
|
return_last_frame: 是否返回最后一帧
|
||||||
|
service_tier: 服务层级,"default" 或 "flex"(离线推理)
|
||||||
|
execution_expires_after: 任务过期时间(秒)
|
||||||
|
draft: 是否生成样片
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成结果(包含任务ID,需要轮询获取结果)
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
content = [{"type": "text", "text": prompt}]
|
||||||
|
|
||||||
|
if draft_task_id:
|
||||||
|
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
|
||||||
|
else:
|
||||||
|
if image_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": image_url}})
|
||||||
|
|
||||||
|
if first_frame_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
|
||||||
|
if last_frame_url:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
|
||||||
|
|
||||||
|
if reference_images:
|
||||||
|
for ref_url in reference_images:
|
||||||
|
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
|
||||||
|
|
||||||
|
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
|
||||||
|
|
||||||
|
if duration:
|
||||||
|
params["duration"] = duration
|
||||||
|
if frames:
|
||||||
|
params["frames"] = frames
|
||||||
|
if ratio:
|
||||||
|
params["ratio"] = ratio
|
||||||
|
if resolution:
|
||||||
|
params["resolution"] = resolution
|
||||||
|
if generate_audio:
|
||||||
|
params["generate_audio"] = generate_audio
|
||||||
|
if camera_fixed:
|
||||||
|
params["camera_fixed"] = camera_fixed
|
||||||
|
if seed is not None:
|
||||||
|
params["seed"] = seed
|
||||||
|
if return_last_frame:
|
||||||
|
params["return_last_frame"] = return_last_frame
|
||||||
|
if service_tier != "default":
|
||||||
|
params["service_tier"] = service_tier
|
||||||
|
if execution_expires_after:
|
||||||
|
params["execution_expires_after"] = execution_expires_after
|
||||||
|
if draft:
|
||||||
|
params["draft"] = draft
|
||||||
|
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.content_generation.tasks.create(**params)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
|
||||||
|
async def agenerate(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
image_url: Optional[str] = None,
|
||||||
|
duration: Optional[int] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""异步生成视频"""
|
||||||
|
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
|
||||||
|
|
||||||
|
def get_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
查询视频生成任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务状态信息
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
response = self._client.content_generation.tasks.get(task_id=task_id)
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
|
||||||
|
"""异步查询任务状态"""
|
||||||
|
return self.get_task_status(task_id)
|
||||||
|
|
||||||
|
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
查询视频生成任务列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
page_size: 每页数量
|
||||||
|
status: 任务状态筛选,如 "succeeded", "failed", "pending"
|
||||||
|
**kwargs: 其他参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务列表
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
params = {"page_size": page_size}
|
||||||
|
if status:
|
||||||
|
params["status"] = status
|
||||||
|
params.update(kwargs)
|
||||||
|
response = self._client.content_generation.tasks.list(**params)
|
||||||
|
return response.model_dump() if hasattr(response, 'model_dump') else response
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_task(self, task_id: str) -> None:
|
||||||
|
"""
|
||||||
|
删除或取消视频生成任务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
"""
|
||||||
|
provider = self._config.provider.lower()
|
||||||
|
|
||||||
|
if provider == ModelProvider.VOLCANO:
|
||||||
|
self._client.content_generation.tasks.delete(task_id=task_id)
|
||||||
|
else:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的提供商: {provider}",
|
||||||
|
code=BizCode.PROVIDER_NOT_SUPPORTED
|
||||||
|
)
|
||||||
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
334
api/app/core/models/scripts/volcano_models.yaml
Normal file
@@ -0,0 +1,334 @@
|
|||||||
|
provider: volcano
|
||||||
|
models:
|
||||||
|
# Doubao-Seed 2.0 系列
|
||||||
|
- name: doubao-seed-2-0-pro-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-lite-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向高频企业场景兼顾性能与成本的均衡型模型,综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-mini-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向低时延、高并发与成本敏感场景,提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解,适合成本和速度优先的轻量级任务。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-2-0-code-preview-260215
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills,可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seed 1.x 系列
|
||||||
|
- name: doubao-seed-1-8-251228
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上,Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面,视觉基础能力显著提升,可低帧率理解超长视频,视频运动理解、复杂空间理解及文档结构化解析能力也有所优化,还原生支持智能上下文管理,用户可配置上下文策略。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-251015
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.6全新多模态深度思考模型,同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-lite-251015
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 更高性价比,常见任务的最佳选择,支持minimal、low、medium、high 四种reasoning_effort思考深度
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-flash-250828
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型,TPOT低至10ms; 同时支持文本和视觉理解,文本理解能力超过上一代lite,视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-code-preview-251028
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 面向Agentic编程任务进行了深度优化。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 代码模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seed-1-6-vision-250815
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新Doubao-Seed-1.6系列视觉深度思考模型,视觉理解能力显著增强,并支持image_process视觉工具
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao 1.5 系列
|
||||||
|
- name: doubao-1-5-vision-pro-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-1-5-pro-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-1-5-lite-32k-250115
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 大语言模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seedance 视频生成系列
|
||||||
|
- name: doubao-seedance-1-5-pro-251215
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-pro-250528
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-pro-fast-251015
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 一款价格触底、效能封顶的全面模型,在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-lite-i2v-250428
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
- 图生视频
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedance-1-0-lite-t2v-250428
|
||||||
|
type: video
|
||||||
|
provider: volcano
|
||||||
|
description: 基于文本提示词生成视频
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 视频生成
|
||||||
|
- 文生视频
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao-Seedream 图像生成系列
|
||||||
|
- name: doubao-seedream-5-0-260128
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-4-5-251128
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 字节跳动最新推出的图像多模态模型,整合了文生图、图生图、组图输出等能力,融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-4-0-250828
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 基于领先架构的SOTA级多模态图像创作模型,其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一,原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
- name: doubao-seedream-3-0-t2i-250415
|
||||||
|
type: image
|
||||||
|
provider: volcano
|
||||||
|
description: 一款支持原生高分辨率的中英双语图像生成基础模型,综合能力媲美GPT-4o,处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 图像生成
|
||||||
|
- 文生图
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao 翻译系列
|
||||||
|
- name: doubao-seed-translation-250915
|
||||||
|
type: chat
|
||||||
|
provider: volcano
|
||||||
|
description: 通用多语言翻译模型,支持30余种语言互译,支持 4K 上下文窗口,输出长度支持最大 3K tokens
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability: []
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 翻译模型
|
||||||
|
logo: volcano
|
||||||
|
|
||||||
|
# Doubao Embedding 系列
|
||||||
|
- name: doubao-embedding-vision-251215
|
||||||
|
type: embedding
|
||||||
|
provider: volcano
|
||||||
|
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
|
||||||
|
is_deprecated: false
|
||||||
|
is_official: true
|
||||||
|
capability:
|
||||||
|
- vision
|
||||||
|
- video
|
||||||
|
is_omni: false
|
||||||
|
tags:
|
||||||
|
- 向量模型
|
||||||
|
- 多模态模型
|
||||||
|
logo: volcano
|
||||||
@@ -61,24 +61,16 @@ class ElasticSearchConfig(BaseModel):
|
|||||||
class ElasticSearchVector(BaseVector):
|
class ElasticSearchVector(BaseVector):
|
||||||
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
|
||||||
super().__init__(index_name.lower())
|
super().__init__(index_name.lower())
|
||||||
# self.embeddings = XinferenceEmbeddings(
|
|
||||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
|
# 初始化 Embedding 模型(自动支持火山引擎多模态)
|
||||||
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
|
|
||||||
# )
|
|
||||||
# Remove debug printing to avoid leaking sensitive information
|
|
||||||
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
|
|
||||||
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
self.embeddings = RedBearEmbeddings(RedBearModelConfig(
|
||||||
model_name=embedding_config.model_name,
|
model_name=embedding_config.model_name,
|
||||||
provider=embedding_config.provider,
|
provider=embedding_config.provider,
|
||||||
api_key=embedding_config.api_key,
|
api_key=embedding_config.api_key,
|
||||||
base_url=embedding_config.api_base
|
base_url=embedding_config.api_base
|
||||||
))
|
))
|
||||||
# self.reranker = XinferenceRerank(
|
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported()
|
||||||
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
|
|
||||||
# model_uid="bge-reranker-large"
|
|
||||||
# )
|
|
||||||
# Remove debug printing to avoid leaking sensitive information
|
|
||||||
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
|
|
||||||
self.reranker = RedBearRerank(RedBearModelConfig(
|
self.reranker = RedBearRerank(RedBearModelConfig(
|
||||||
model_name=reranker_config.model_name,
|
model_name=reranker_config.model_name,
|
||||||
provider=reranker_config.provider,
|
provider=reranker_config.provider,
|
||||||
@@ -144,7 +136,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
|
||||||
# 实现 Elasticsearch 保存向量
|
# 实现 Elasticsearch 保存向量
|
||||||
texts = [chunk.page_content for chunk in chunks]
|
texts = [chunk.page_content for chunk in chunks]
|
||||||
embeddings = self.embeddings.embed_documents(list(texts))
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
embeddings = self.embeddings.embed_batch(texts)
|
||||||
|
else:
|
||||||
|
embeddings = self.embeddings.embed_documents(list(texts))
|
||||||
self.create(chunks, embeddings, **kwargs)
|
self.create(chunks, embeddings, **kwargs)
|
||||||
|
|
||||||
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
|
||||||
@@ -394,7 +390,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
updated count.
|
updated count.
|
||||||
"""
|
"""
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
chunk.vector = self.embeddings.embed_text(chunk.page_content)
|
||||||
|
else:
|
||||||
|
chunk.vector = self.embeddings.embed_query(chunk.page_content)
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
"script": {
|
"script": {
|
||||||
@@ -454,7 +454,11 @@ class ElasticSearchVector(BaseVector):
|
|||||||
|
|
||||||
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
|
||||||
"""Search the nearest neighbors to a vector."""
|
"""Search the nearest neighbors to a vector."""
|
||||||
query_vector = self.embeddings.embed_query(query)
|
if self.is_multimodal_embedding:
|
||||||
|
# 火山引擎多模态 Embedding
|
||||||
|
query_vector = self.embeddings.embed_text(query)
|
||||||
|
else:
|
||||||
|
query_vector = self.embeddings.embed_query(query)
|
||||||
top_k = kwargs.get("top_k", 1024)
|
top_k = kwargs.get("top_k", 1024)
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
score_threshold = float(kwargs.get("score_threshold") or 0.3)
|
||||||
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index available,etc "index1,index2,index3"
|
||||||
|
|||||||
@@ -109,15 +109,26 @@ class StorageBackend(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
|
"""Get an access URL for the file."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Get an access URL for the file.
|
Get a permanent public URL for the file (no expiration).
|
||||||
|
|
||||||
|
Returns None by default; remote storage backends should override this
|
||||||
|
if the bucket is configured for public read access.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
URL for accessing the file.
|
A permanent public URL, or None if not supported.
|
||||||
"""
|
"""
|
||||||
pass
|
return None
|
||||||
|
|||||||
@@ -210,7 +210,12 @@ class LocalStorage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get an access URL for the file.
|
Get an access URL for the file.
|
||||||
|
|
||||||
@@ -220,6 +225,7 @@ class LocalStorage(StorageBackend):
|
|||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (not used for local storage).
|
expires: URL validity period in seconds (not used for local storage).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A relative URL path for accessing the file.
|
A relative URL path for accessing the file.
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ Storage Service (OSS) using the oss2 SDK.
|
|||||||
|
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import urllib.parse
|
||||||
from typing import AsyncIterator, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import oss2
|
import oss2
|
||||||
@@ -43,6 +44,8 @@ class OSSStorage(StorageBackend):
|
|||||||
access_key_id: str,
|
access_key_id: str,
|
||||||
access_key_secret: str,
|
access_key_secret: str,
|
||||||
bucket_name: str,
|
bucket_name: str,
|
||||||
|
connect_timeout: int = 30,
|
||||||
|
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the OSSStorage backend.
|
Initialize the OSSStorage backend.
|
||||||
@@ -52,6 +55,8 @@ class OSSStorage(StorageBackend):
|
|||||||
access_key_id: The Aliyun access key ID.
|
access_key_id: The Aliyun access key ID.
|
||||||
access_key_secret: The Aliyun access key secret.
|
access_key_secret: The Aliyun access key secret.
|
||||||
bucket_name: The name of the OSS bucket.
|
bucket_name: The name of the OSS bucket.
|
||||||
|
connect_timeout: Connection timeout in seconds (default: 30).
|
||||||
|
multipart_threshold: File size threshold for multipart upload (default: 10MB).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StorageConfigError: If any required configuration is missing.
|
StorageConfigError: If any required configuration is missing.
|
||||||
@@ -68,10 +73,17 @@ class OSSStorage(StorageBackend):
|
|||||||
|
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.bucket_name = bucket_name
|
self.bucket_name = bucket_name
|
||||||
|
self.multipart_threshold = multipart_threshold
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth = oss2.Auth(access_key_id, access_key_secret)
|
auth = oss2.Auth(access_key_id, access_key_secret)
|
||||||
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
|
# 设置超时和重试
|
||||||
|
self.bucket = oss2.Bucket(
|
||||||
|
auth,
|
||||||
|
endpoint,
|
||||||
|
bucket_name,
|
||||||
|
connect_timeout=connect_timeout
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
||||||
)
|
)
|
||||||
@@ -107,21 +119,38 @@ class OSSStorage(StorageBackend):
|
|||||||
if content_type:
|
if content_type:
|
||||||
headers["Content-Type"] = content_type
|
headers["Content-Type"] = content_type
|
||||||
|
|
||||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
# 大文件使用分片上传
|
||||||
|
if len(content) > self.multipart_threshold:
|
||||||
|
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
|
||||||
|
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
|
||||||
|
parts = []
|
||||||
|
part_size = 5 * 1024 * 1024 # 5MB per part
|
||||||
|
part_num = 1
|
||||||
|
|
||||||
|
for offset in range(0, len(content), part_size):
|
||||||
|
chunk = content[offset:offset + part_size]
|
||||||
|
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||||
|
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||||
|
part_num += 1
|
||||||
|
|
||||||
|
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||||
|
else:
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||||
|
|
||||||
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
||||||
return file_key
|
return file_key
|
||||||
|
|
||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error uploading file {file_key}: {e}")
|
logger.error(f"OSS error uploading file {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to upload file to OSS: {e.message}",
|
message=f"Failed to upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to upload file to OSS: {e}",
|
message=f"Failed to upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
@@ -134,28 +163,73 @@ class OSSStorage(StorageBackend):
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
|
headers = {"Content-Type": content_type} if content_type else None
|
||||||
|
upload_id = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 收集流数据
|
||||||
|
total_size = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
buf.write(chunk)
|
buf.write(chunk)
|
||||||
|
total_size += len(chunk)
|
||||||
|
|
||||||
content = buf.getvalue()
|
content = buf.getvalue()
|
||||||
headers = {"Content-Type": content_type} if content_type else None
|
|
||||||
self.bucket.put_object(file_key, content, headers=headers)
|
if not content:
|
||||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
raise StorageUploadError(
|
||||||
return len(content)
|
message="Empty stream content",
|
||||||
|
file_key=file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 大文件使用分片上传
|
||||||
|
if len(content) > self.multipart_threshold:
|
||||||
|
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
|
||||||
|
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
|
||||||
|
parts = []
|
||||||
|
part_size = 5 * 1024 * 1024 # 5MB
|
||||||
|
part_num = 1
|
||||||
|
|
||||||
|
for offset in range(0, len(content), part_size):
|
||||||
|
chunk = content[offset:offset + part_size]
|
||||||
|
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||||
|
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||||
|
part_num += 1
|
||||||
|
|
||||||
|
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||||
|
else:
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers)
|
||||||
|
|
||||||
|
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
|
||||||
|
return total_size
|
||||||
|
|
||||||
except OssError as e:
|
except OssError as e:
|
||||||
|
if upload_id:
|
||||||
|
try:
|
||||||
|
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if upload_id:
|
||||||
|
try:
|
||||||
|
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to stream upload file to OSS: {e}",
|
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
buf.close()
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
@@ -181,14 +255,14 @@ class OSSStorage(StorageBackend):
|
|||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error downloading file {file_key}: {e}")
|
logger.error(f"OSS error downloading file {file_key}: {e}")
|
||||||
raise StorageDownloadError(
|
raise StorageDownloadError(
|
||||||
message=f"Failed to download file from OSS: {e.message}",
|
message=f"Failed to download file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
||||||
raise StorageDownloadError(
|
raise StorageDownloadError(
|
||||||
message=f"Failed to download file from OSS: {e}",
|
message=f"Failed to download file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
@@ -214,14 +288,14 @@ class OSSStorage(StorageBackend):
|
|||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error deleting file {file_key}: {e}")
|
logger.error(f"OSS error deleting file {file_key}: {e}")
|
||||||
raise StorageDeleteError(
|
raise StorageDeleteError(
|
||||||
message=f"Failed to delete file from OSS: {e.message}",
|
message=f"Failed to delete file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
||||||
raise StorageDeleteError(
|
raise StorageDeleteError(
|
||||||
message=f"Failed to delete file from OSS: {e}",
|
message=f"Failed to delete file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
@@ -242,22 +316,41 @@ class OSSStorage(StorageBackend):
|
|||||||
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
|
logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get a presigned URL for accessing the file.
|
Get a presigned URL for accessing the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
expires: URL validity period in seconds (default: 1 hour).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A presigned URL for accessing the file.
|
A presigned URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
url = self.bucket.sign_url("GET", file_key, expires)
|
params = {}
|
||||||
|
if file_name:
|
||||||
|
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||||
|
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||||
|
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
|
||||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
|
||||||
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.{endpoint}/{file_key}
|
||||||
|
"""
|
||||||
|
host = self.endpoint.replace("https://", "").replace("http://", "")
|
||||||
|
return f"https://{self.bucket_name}.{host}/{file_key}"
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ using the boto3 SDK.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import io
|
import io
|
||||||
|
import urllib.parse
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncIterator, Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
@@ -352,29 +353,44 @@ class S3Storage(StorageBackend):
|
|||||||
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
|
logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
expires: int = 3600,
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Get a presigned URL for accessing the file.
|
Get a presigned URL for accessing the file.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
expires: URL validity period in seconds (default: 1 hour).
|
expires: URL validity period in seconds (default: 1 hour).
|
||||||
|
file_name: If set, adds Content-Disposition: attachment to force download.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A presigned URL for accessing the file.
|
A presigned URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
params = {"Bucket": self.bucket_name, "Key": file_key}
|
||||||
|
if file_name:
|
||||||
|
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
|
||||||
|
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
|
||||||
url = self.client.generate_presigned_url(
|
url = self.client.generate_presigned_url(
|
||||||
"get_object",
|
"get_object",
|
||||||
Params={
|
Params=params,
|
||||||
"Bucket": self.bucket_name,
|
|
||||||
"Key": file_key,
|
|
||||||
},
|
|
||||||
ExpiresIn=expires,
|
ExpiresIn=expires,
|
||||||
)
|
)
|
||||||
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
|
||||||
return url
|
return url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
|
||||||
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key}
|
||||||
|
"""
|
||||||
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|||||||
@@ -99,7 +99,7 @@ class SimpleMCPClient:
|
|||||||
# 建立 SSE 连接
|
# 建立 SSE 连接
|
||||||
response = await self._session.get(self.server_url)
|
response = await self._session.get(self.server_url)
|
||||||
|
|
||||||
if response.status != 200:
|
if response.status not in (200, 202):
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
@@ -190,7 +190,9 @@ class SimpleMCPClient:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||||
if response.status != 200:
|
# MCP SSE 协议:POST 请求返回 200 或 202 均为正常
|
||||||
|
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
|
||||||
|
if response.status not in (200, 202):
|
||||||
error_text = await response.text()
|
error_text = await response.text()
|
||||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
@@ -205,7 +207,7 @@ class SimpleMCPClient:
|
|||||||
raise MCPConnectionError("endpoint URL 未初始化")
|
raise MCPConnectionError("endpoint URL 未初始化")
|
||||||
|
|
||||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||||
if response.status != 200:
|
if response.status not in (200, 202):
|
||||||
logger.warning(f"通知发送失败: {response.status}")
|
logger.warning(f"通知发送失败: {response.status}")
|
||||||
|
|
||||||
async def _initialize_modelscope_session(self):
|
async def _initialize_modelscope_session(self):
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import Any
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
from app.core.workflow.adapters.errors import ExceptionDefinition
|
||||||
from app.schemas.workflow_schema import (
|
from app.schemas.workflow_schema import (
|
||||||
EdgeDefinition,
|
EdgeDefinition,
|
||||||
NodeDefinition,
|
NodeDefinition,
|
||||||
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
|
|||||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowImportResult(BaseModel):
|
class WorkflowImportResult(BaseModel):
|
||||||
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
|
|||||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
warnings: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
errors: list[ExceptionDefinition] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class BasePlatformAdapter(ABC):
|
class BasePlatformAdapter(ABC):
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ from urllib.parse import quote
|
|||||||
|
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import (
|
from app.core.workflow.adapters.errors import (
|
||||||
UnsupportVariableType,
|
UnsupportedVariableType,
|
||||||
UnknowModelWarning,
|
UnknownModelWarning,
|
||||||
ExceptionDefineition,
|
ExceptionDefinition,
|
||||||
ExceptionType
|
ExceptionType
|
||||||
)
|
)
|
||||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||||
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
|
|||||||
HttpFormData,
|
HttpFormData,
|
||||||
HttpTimeOutConfig,
|
HttpTimeOutConfig,
|
||||||
HttpRetryConfig,
|
HttpRetryConfig,
|
||||||
HttpErrorDefaultTamplete,
|
HttpErrorDefaultTemplate,
|
||||||
HttpErrorHandleConfig
|
HttpErrorHandleConfig
|
||||||
)
|
)
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||||
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
|
|||||||
try:
|
try:
|
||||||
return config.model_validate(value)
|
return config.model_validate(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
|
|||||||
var_selector = mapping.get(var_selector, var_selector)
|
var_selector = mapping.get(var_selector, var_selector)
|
||||||
return var_selector
|
return var_selector
|
||||||
|
|
||||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
def _process_list_variable_literal(self, variable_selector: list) -> str | None:
|
||||||
if not self.process_var_selector(".".join(variable_selector)):
|
if not self.process_var_selector(".".join(variable_selector)):
|
||||||
return None
|
return None
|
||||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||||
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
|
|||||||
var_type = self.variable_type_map(var["type"])
|
var_type = self.variable_type_map(var["type"])
|
||||||
if not var_type:
|
if not var_type:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
UnsupportVariableType(
|
UnsupportedVariableType(
|
||||||
scope=node["id"],
|
scope=node["id"],
|
||||||
name=var["variable"],
|
name=var["variable"],
|
||||||
var_type=var["type"],
|
var_type=var["type"],
|
||||||
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
if var_type in ["file", "array[file]"]:
|
if var_type in ["file", "array[file]"]:
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
ExceptionDefineition(
|
ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
result = QuestionClassifierNodeConfig.model_construct(
|
result = QuestionClassifierNodeConfig.model_construct(
|
||||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")),
|
||||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||||
categories=categories,
|
categories=categories,
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_llm_node_config(self, node: dict) -> dict:
|
def convert_llm_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
context = self._process_list_variable_literal(node_data["context"]["variable_selector"])
|
||||||
memory = MemoryWindowSetting(
|
memory = MemoryWindowSetting(
|
||||||
enable=bool(node_data.get("memory")),
|
enable=bool(node_data.get("memory")),
|
||||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||||
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
vision = node_data["vision"]["enabled"]
|
vision = node_data["vision"]["enabled"]
|
||||||
vision_input = self._process_list_variable_litearl(
|
vision_input = self._process_list_variable_literal(
|
||||||
node_data["vision"]["configs"]["variable_selector"]
|
node_data["vision"]["configs"]["variable_selector"]
|
||||||
) if vision else None
|
) if vision else None
|
||||||
result = LLMNodeConfig.model_construct(
|
result = LLMNodeConfig.model_construct(
|
||||||
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
|
|||||||
conditions.append(
|
conditions.append(
|
||||||
LoopConditionDetail.model_construct(
|
LoopConditionDetail.model_construct(
|
||||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
left=self._process_list_variable_literal(condition["variable_selector"]),
|
||||||
right=self.trans_variable_format(
|
right=self.trans_variable_format(
|
||||||
right_value
|
right_value
|
||||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||||
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
|
|||||||
right_input_type = variable["value_type"]
|
right_input_type = variable["value_type"]
|
||||||
right_value_type = self.variable_type_map(variable["var_type"])
|
right_value_type = self.variable_type_map(variable["var_type"])
|
||||||
if right_input_type == ValueInputType.VARIABLE:
|
if right_input_type == ValueInputType.VARIABLE:
|
||||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
right_value = self._process_list_variable_literal(variable.get("value", ""))
|
||||||
else:
|
else:
|
||||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||||
loop_variables.append(
|
loop_variables.append(
|
||||||
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
result = IterationNodeConfig.model_construct(
|
result = IterationNodeConfig.model_construct(
|
||||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
input=self._process_list_variable_literal(node_data["iterator_selector"]),
|
||||||
parallel=node_data["is_parallel"],
|
parallel=node_data["is_parallel"],
|
||||||
parallel_count=node_data["parallel_nums"],
|
parallel_count=node_data["parallel_nums"],
|
||||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
output=self._process_list_variable_literal(node_data["output_selector"]),
|
||||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||||
flatten=node_data["flatten_output"],
|
flatten=node_data["flatten_output"],
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
|
|||||||
continue
|
continue
|
||||||
assignments.append(
|
assignments.append(
|
||||||
AssignmentItem(
|
AssignmentItem(
|
||||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
variable_selector=self._process_list_variable_literal(assignment["variable_selector"]),
|
||||||
value=self._process_list_variable_litearl(
|
value=self._process_list_variable_literal(
|
||||||
assignment["value"]
|
assignment["value"]
|
||||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||||
operation=self.convert_assignment_operator(assignment["operation"])
|
operation=self.convert_assignment_operator(assignment["operation"])
|
||||||
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
|
|||||||
input_variables.append(
|
input_variables.append(
|
||||||
InputVariable.model_construct(
|
InputVariable.model_construct(
|
||||||
name=input_variable["variable"],
|
name=input_variable["variable"],
|
||||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
variable=self._process_list_variable_literal(input_variable["value_selector"]),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
|
|||||||
else:
|
else:
|
||||||
if node_data["body"]["data"]:
|
if node_data["body"]["data"]:
|
||||||
body_content = (node_data["body"]["data"][0].get("value") or
|
body_content = (node_data["body"]["data"][0].get("value") or
|
||||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
self._process_list_variable_literal(node_data["body"]["data"][0].get("file")))
|
||||||
else:
|
else:
|
||||||
body_content = ""
|
body_content = ""
|
||||||
|
|
||||||
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
|
|||||||
self.trans_variable_format(key_value[0])
|
self.trans_variable_format(key_value[0])
|
||||||
] = self.trans_variable_format(key_value[1])
|
] = self.trans_variable_format(key_value[1])
|
||||||
else:
|
else:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
|
|||||||
self.trans_variable_format(key_value[0])
|
self.trans_variable_format(key_value[0])
|
||||||
] = self.trans_variable_format(key_value[1])
|
] = self.trans_variable_format(key_value[1])
|
||||||
else:
|
else:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
|
|||||||
default_header = var["value"]
|
default_header = var["value"]
|
||||||
elif var["key"] == "status_code":
|
elif var["key"] == "status_code":
|
||||||
default_status_code = var["value"]
|
default_status_code = var["value"]
|
||||||
default_value = HttpErrorDefaultTamplete(
|
default_value = HttpErrorDefaultTemplate(
|
||||||
body=default_body,
|
body=default_body,
|
||||||
headers=default_header,
|
headers=default_header,
|
||||||
status_code=default_status_code,
|
status_code=default_status_code,
|
||||||
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
|
|||||||
for variable in node_data["variables"]:
|
for variable in node_data["variables"]:
|
||||||
mapping.append(VariablesMappingConfig.model_construct(
|
mapping.append(VariablesMappingConfig.model_construct(
|
||||||
name=variable["variable"],
|
name=variable["variable"],
|
||||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
value=self._process_list_variable_literal(variable["value_selector"])
|
||||||
))
|
))
|
||||||
result = JinjaRenderNodeConfig.model_construct(
|
result = JinjaRenderNodeConfig.model_construct(
|
||||||
template=node_data["template"],
|
template=node_data["template"],
|
||||||
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||||
))
|
))
|
||||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
query=self._process_list_variable_literal(node_data["query_variable_selector"]),
|
||||||
).model_dump()
|
).model_dump()
|
||||||
|
|
||||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||||
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
|
|||||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(
|
self.warnings.append(
|
||||||
UnknowModelWarning(
|
UnknownModelWarning(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
model_name=node_data["model"].get("name")
|
model_name=node_data["model"].get("name")
|
||||||
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = ParameterExtractorNodeConfig.model_construct(
|
result = ParameterExtractorNodeConfig.model_construct(
|
||||||
text=self._process_list_variable_litearl(node_data["query"]),
|
text=self._process_list_variable_literal(node_data["query"]),
|
||||||
params=params,
|
params=params,
|
||||||
prompt=node_data.get("instruction")
|
prompt=node_data.get("instruction")
|
||||||
).model_dump()
|
).model_dump()
|
||||||
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
|
|||||||
group_type = {}
|
group_type = {}
|
||||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||||
group_variables = [
|
group_variables = [
|
||||||
self._process_list_variable_litearl(variable)
|
self._process_list_variable_literal(variable)
|
||||||
for variable in node_data["variables"]
|
for variable in node_data["variables"]
|
||||||
]
|
]
|
||||||
group_type["output"] = node_data["output_type"]
|
group_type["output"] = node_data["output_type"]
|
||||||
else:
|
else:
|
||||||
for group in advanced_settings["groups"]:
|
for group in advanced_settings["groups"]:
|
||||||
group_variables[group["group_name"]] = [
|
group_variables[group["group_name"]] = [
|
||||||
self._process_list_variable_litearl(variable)
|
self._process_list_variable_literal(variable)
|
||||||
for variable in group["variables"]
|
for variable in group["variables"]
|
||||||
]
|
]
|
||||||
group_type[group["group_name"]] = group["output_type"]
|
group_type[group["group_name"]] = group["output_type"]
|
||||||
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def convert_tool_node_config(self, node: dict) -> dict:
|
def convert_tool_node_config(self, node: dict) -> dict:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node_data["title"],
|
node_name=node_data["title"],
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
|
|||||||
WorkflowParserResult
|
WorkflowParserResult
|
||||||
)
|
)
|
||||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.schemas.workflow_schema import (
|
from app.schemas.workflow_schema import (
|
||||||
NodeDefinition,
|
NodeDefinition,
|
||||||
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
if self.config.get("app", {}).get("mode") == "workflow":
|
if self.config.get("app", {}).get("mode") == "workflow":
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.PLATFORM,
|
type=ExceptionType.PLATFORM,
|
||||||
detail="workflow mode is not supported"
|
detail="workflow mode is not supported"
|
||||||
))
|
))
|
||||||
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
edge = self._convert_edge(edge)
|
edge = self._convert_edge(edge)
|
||||||
if edge:
|
if edge:
|
||||||
self.edges.append(edge)
|
self.edges.append(edge)
|
||||||
#
|
|
||||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||||
con_var = self._convert_variable(variable)
|
con_var = self._convert_variable(variable)
|
||||||
if variable:
|
if variable:
|
||||||
self.conv_variables.append(con_var)
|
self.conv_variables.append(con_var)
|
||||||
#
|
|
||||||
# for variables in config.get("workflow").get("environment_variables"):
|
# for variables in config.get("workflow").get("environment_variables"):
|
||||||
# variable = self._convert_variable(variables)
|
# variable = self._convert_variable(variables)
|
||||||
# conv_variables.append(variable)
|
# conv_variables.append(variable)
|
||||||
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
"y": node["position"]["y"] + position["y"]
|
"y": node["position"]["y"] + position["y"]
|
||||||
}
|
}
|
||||||
self.errors.append(
|
self.errors.append(
|
||||||
ExceptionDefineition(
|
ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
detail="parent cycle node not found"
|
detail="parent cycle node not found"
|
||||||
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
converter = self.get_node_convert(node_type)
|
converter = self.get_node_convert(node_type)
|
||||||
if node_type == NodeType.UNKNOWN:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
))
|
))
|
||||||
return converter(node)
|
return converter(node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
@@ -207,7 +207,6 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||||
try:
|
try:
|
||||||
|
|
||||||
source = edge["source"]
|
source = edge["source"]
|
||||||
target = edge["target"]
|
target = edge["target"]
|
||||||
label = None
|
label = None
|
||||||
@@ -230,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
label=label,
|
label=label,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"convert edge error - {e}",
|
detail=f"convert edge error - {e}",
|
||||||
))
|
))
|
||||||
@@ -246,7 +245,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
description=variable.get("description")
|
description=variable.get("description")
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
name=variable.get("name"),
|
name=variable.get("name"),
|
||||||
detail=f"convert variable error - {e}",
|
detail=f"convert variable error - {e}",
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
|
|||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
class ExceptionDefineition(BaseModel):
|
class ExceptionDefinition(BaseModel):
|
||||||
type: ExceptionType
|
type: ExceptionType
|
||||||
detail: str
|
detail: str
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ class ExceptionDefineition(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class UnknowModelWarning(ExceptionDefineition):
|
class UnknownModelWarning(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.NODE
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
def __init__(self, node_id, node_name, model_name):
|
def __init__(self, node_id, node_name, model_name):
|
||||||
@@ -40,36 +40,36 @@ class UnknowModelWarning(ExceptionDefineition):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class UnknowError(ExceptionDefineition):
|
class UnknownError(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.UNKNOWN
|
type: ExceptionType = ExceptionType.UNKNOWN
|
||||||
|
|
||||||
def __init__(self, detail: str, **kwargs):
|
def __init__(self, detail: str, **kwargs):
|
||||||
super().__init__(detail=detail, **kwargs)
|
super().__init__(detail=detail, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class UnsupportPlatform(ExceptionDefineition):
|
class UnsupportedPlatform(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.PLATFORM
|
type: ExceptionType = ExceptionType.PLATFORM
|
||||||
|
|
||||||
def __init__(self, platform: str):
|
def __init__(self, platform: str):
|
||||||
super().__init__(detail=f"Unsupport platform {platform}")
|
super().__init__(detail=f"Unsupported platform {platform}")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportVariableType(ExceptionDefineition):
|
class UnsupportedVariableType(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.VARIABLE
|
type: ExceptionType = ExceptionType.VARIABLE
|
||||||
|
|
||||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class InvalidConfiguration(ExceptionDefineition):
|
class InvalidConfiguration(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.CONFIG
|
type: ExceptionType = ExceptionType.CONFIG
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(detail="Invalid workflow configuration format")
|
super().__init__(detail="Invalid workflow configuration format")
|
||||||
|
|
||||||
|
|
||||||
class UnsupportNodeType(ExceptionDefineition):
|
class UnsupportedNodeType(ExceptionDefinition):
|
||||||
type: ExceptionType = ExceptionType.NODE
|
type: ExceptionType = ExceptionType.NODE
|
||||||
|
|
||||||
def __init__(self, node_id: str, node_type: str):
|
def __init__(self, node_id: str, node_type: str):
|
||||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}")
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
|
|||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
WorkflowParserResult
|
WorkflowParserResult
|
||||||
)
|
)
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType
|
||||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||||
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
try:
|
try:
|
||||||
node_type = self.map_node_type(node["type"])
|
node_type = self.map_node_type(node["type"])
|
||||||
if node_type == NodeType.UNKNOWN:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(UnsupportNodeType(
|
self.errors.append(UnsupportedNodeType(
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_type=node["type"]
|
node_type=node["type"]
|
||||||
))
|
))
|
||||||
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
|
|
||||||
return NodeDefinition(**node)
|
return NodeDefinition(**node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||||
try:
|
try:
|
||||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||||
))
|
))
|
||||||
return None
|
return None
|
||||||
return EdgeDefinition(**edge)
|
return EdgeDefinition(**edge)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.EDGE,
|
type=ExceptionType.EDGE,
|
||||||
detail=f"convert edge error - {e}"
|
detail=f"convert edge error - {e}"
|
||||||
))
|
))
|
||||||
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
|||||||
try:
|
try:
|
||||||
return VariableDefinition(**variable)
|
return VariableDefinition(**variable)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.warnings.append(ExceptionDefineition(
|
self.warnings.append(ExceptionDefinition(
|
||||||
type=ExceptionType.VARIABLE,
|
type=ExceptionType.VARIABLE,
|
||||||
name=variable.get("name"),
|
name=variable.get("name"),
|
||||||
detail=f"convert variable error - {e}"
|
detail=f"convert variable error - {e}"
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# -*- coding: UTF-8 -*-
|
# -*- coding: UTF-8 -*-
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.configs import (
|
from app.core.workflow.nodes.configs import (
|
||||||
StartNodeConfig,
|
StartNodeConfig,
|
||||||
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
|
|||||||
try:
|
try:
|
||||||
return config_cls.model_validate(value)
|
return config_cls.model_validate(value)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefinition(
|
||||||
type=ExceptionType.CONFIG,
|
type=ExceptionType.CONFIG,
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
node_name=node_name,
|
node_name=node_name,
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable, Callable
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
@@ -41,48 +41,31 @@ class GraphBuilder:
|
|||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
subgraph: bool = False,
|
cycle: str = '',
|
||||||
variable_pool: VariablePool | None = None
|
variable_pool: VariablePool | None = None
|
||||||
):
|
):
|
||||||
self.workflow_config = workflow_config
|
self.workflow_config = workflow_config
|
||||||
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.cycle = cycle
|
||||||
|
|
||||||
self.start_node_id: str | None = None
|
self.start_node_id: str | None = None
|
||||||
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map: dict[str, dict] = {}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_activation_dep = lru_cache(
|
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
|
||||||
maxsize=len(self.nodes) * 2
|
|
||||||
)(self._find_upstream_activation_dep)
|
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
self.variable_pool = VariablePool()
|
self.variable_pool = VariablePool()
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph: StateGraph | None = None
|
||||||
self.add_nodes()
|
self.nodes: list = []
|
||||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
self.edges: list = []
|
||||||
self.end_nodes = [
|
self.reachable_nodes: set[str] | None = None
|
||||||
node
|
self.end_nodes: list[dict] = []
|
||||||
for node in self.nodes
|
|
||||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
|
||||||
]
|
|
||||||
self.add_edges()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
|
||||||
|
|
||||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
self._build_reverse_adj()
|
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||||
self._analyze_end_node_output()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("nodes", [])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def edges(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("edges", [])
|
|
||||||
|
|
||||||
def get_node_type(self, node_id: str) -> str:
|
def get_node_type(self, node_id: str) -> str:
|
||||||
"""Retrieve the type of node given its ID.
|
"""Retrieve the type of node given its ID.
|
||||||
@@ -108,13 +91,14 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _build_reverse_adj(self):
|
def _build_adj(self):
|
||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
if edge["source"] not in self.reachable_nodes:
|
if edge["source"] not in self.reachable_nodes:
|
||||||
continue
|
continue
|
||||||
self._reverse_adj[edge.get("target")].append({
|
self._reverse_adj[edge.get("target")].append({
|
||||||
"id": edge["source"], "branch": edge.get("label")
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
})
|
})
|
||||||
|
self._adj[edge.get("source")].append(edge["target"])
|
||||||
|
|
||||||
def _find_upstream_activation_dep(
|
def _find_upstream_activation_dep(
|
||||||
self,
|
self,
|
||||||
@@ -302,22 +286,13 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
if node_type == NodeType.NOTES:
|
|
||||||
continue
|
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
cycle_node = node.get("cycle")
|
if node_id not in self.reachable_nodes:
|
||||||
if cycle_node:
|
continue
|
||||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
|
||||||
if not self.subgraph:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Record start and end node IDs
|
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
|
||||||
self.start_node_id = node_id
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
|
||||||
|
|
||||||
if node_type in BRANCH_NODES:
|
if node_type in BRANCH_NODES:
|
||||||
|
|
||||||
@@ -390,6 +365,8 @@ class GraphBuilder:
|
|||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
target = edge.get("target")
|
target = edge.get("target")
|
||||||
|
if source not in self.reachable_nodes or target not in self.reachable_nodes:
|
||||||
|
continue
|
||||||
condition = edge.get("condition")
|
condition = edge.get("condition")
|
||||||
edge_type = edge.get("type")
|
edge_type = edge.get("type")
|
||||||
|
|
||||||
@@ -411,11 +388,12 @@ class GraphBuilder:
|
|||||||
# Add conditional edges
|
# Add conditional edges
|
||||||
for source_node, branches in conditional_edges.items():
|
for source_node, branches in conditional_edges.items():
|
||||||
def make_router(src, branch_list):
|
def make_router(src, branch_list):
|
||||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
"""Create a router function for each source node that routes to a NOP node for later merging."""
|
||||||
|
|
||||||
def make_branch_node(node_name, targets):
|
def make_branch_node(node_name, targets):
|
||||||
def node(s):
|
def node(s):
|
||||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
# NOTE: NOP NODE USED FOR ROUTING ONLY.
|
||||||
|
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
|
||||||
return {
|
return {
|
||||||
"activate": {
|
"activate": {
|
||||||
node_id: s["activate"][node_name]
|
node_id: s["activate"][node_name]
|
||||||
@@ -502,14 +480,52 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node in self.end_nodes:
|
for node in self.reachable_nodes:
|
||||||
end_node_id = end_node.get("id")
|
if not self._adj[node]:
|
||||||
if end_node_id:
|
self.graph.add_edge(node, END)
|
||||||
self.graph.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
nodes = self.workflow_config.get("nodes", [])
|
||||||
|
edges = self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if (node.get("cycle") or '') == self.cycle:
|
||||||
|
node_type = node.get("type")
|
||||||
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
self.start_node_id = node.get("id")
|
||||||
|
elif node_type == NodeType.NOTES:
|
||||||
|
continue
|
||||||
|
self.nodes.append(node)
|
||||||
|
self.node_map[node.get("id")] = node
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_in = edge.get("source") in self.node_map
|
||||||
|
target_in = edge.get("target") in self.node_map
|
||||||
|
if source_in ^ target_in:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cycle node is connected to external node, "
|
||||||
|
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_in and target_in:
|
||||||
|
self.edges.append(edge)
|
||||||
|
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
|
self._build_adj()
|
||||||
|
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||||
|
maxsize=len(self.nodes)*2
|
||||||
|
)(self._find_upstream_activation_dep)
|
||||||
|
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
|
self.add_edges()
|
||||||
|
|
||||||
|
self._analyze_end_node_output()
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
return self.graph.compile(checkpointer=checkpointer)
|
||||||
return self.graph
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# Author: Eternity
|
# Author: Eternity
|
||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/10 13:33
|
# @Time : 2026/2/10 13:33
|
||||||
|
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
@@ -9,6 +10,7 @@ class WorkflowResultBuilder:
|
|||||||
def build_final_output(
|
def build_final_output(
|
||||||
self,
|
self,
|
||||||
result: dict,
|
result: dict,
|
||||||
|
execution_context: ExecutionContext,
|
||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
final_output: str,
|
final_output: str,
|
||||||
@@ -26,6 +28,8 @@ class WorkflowResultBuilder:
|
|||||||
- "node_outputs" (dict): Outputs of executed nodes.
|
- "node_outputs" (dict): Outputs of executed nodes.
|
||||||
- "messages" (list): Conversation messages exchanged during execution.
|
- "messages" (list): Conversation messages exchanged during execution.
|
||||||
- "error" (str, optional): Error message if any node failed.
|
- "error" (str, optional): Error message if any node failed.
|
||||||
|
execution_context (ExecutionContext): The execution context containing metadata like
|
||||||
|
execution ID, workspace ID, and user ID.)
|
||||||
variable_pool (VariablePool): Variable Pool
|
variable_pool (VariablePool): Variable Pool
|
||||||
elapsed_time (float): Total execution time in seconds.
|
elapsed_time (float): Total execution time in seconds.
|
||||||
final_output (Any): The aggregated or final output content of the workflow
|
final_output (Any): The aggregated or final output content of the workflow
|
||||||
@@ -48,18 +52,23 @@ class WorkflowResultBuilder:
|
|||||||
"""
|
"""
|
||||||
node_outputs = result.get("node_outputs", {})
|
node_outputs = result.get("node_outputs", {})
|
||||||
token_usage = self.aggregate_token_usage(node_outputs)
|
token_usage = self.aggregate_token_usage(node_outputs)
|
||||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
conversation_vars = {}
|
||||||
|
sys_vars = {}
|
||||||
|
|
||||||
|
if variable_pool:
|
||||||
|
conversation_vars = variable_pool.get_all_conversation_vars()
|
||||||
|
sys_vars = variable_pool.get_all_system_vars()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed" if success else "failed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
"variables": {
|
"variables": {
|
||||||
"conv": variable_pool.get_all_conversation_vars(),
|
"conv": conversation_vars,
|
||||||
"sys": variable_pool.get_all_system_vars()
|
"sys": sys_vars
|
||||||
},
|
},
|
||||||
"node_outputs": node_outputs,
|
"node_outputs": node_outputs,
|
||||||
"messages": result.get("messages", []),
|
"messages": result.get("messages", []),
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": execution_context.conversation_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"token_usage": token_usage,
|
"token_usage": token_usage,
|
||||||
"error": result.get("error"),
|
"error": result.get("error"),
|
||||||
|
|||||||
@@ -12,14 +12,29 @@ class ExecutionContext(BaseModel):
|
|||||||
execution_id: str
|
execution_id: str
|
||||||
workspace_id: str
|
workspace_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
conversation_id: str
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
checkpoint_config: RunnableConfig
|
checkpoint_config: RunnableConfig
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
def create(
|
||||||
|
cls,
|
||||||
|
execution_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
user_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
|
):
|
||||||
return cls(
|
return cls(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
|
||||||
checkpoint_config=RunnableConfig(
|
checkpoint_config=RunnableConfig(
|
||||||
configurable={
|
configurable={
|
||||||
"thread_id": uuid.uuid4(),
|
"thread_id": uuid.uuid4(),
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ class WorkflowState(dict):
|
|||||||
"workspace_id",
|
"workspace_id",
|
||||||
"user_id",
|
"user_id",
|
||||||
"activate",
|
"activate",
|
||||||
|
"memory_storage_type",
|
||||||
|
"user_rag_memory_id"
|
||||||
})
|
})
|
||||||
__optional_keys__ = frozenset({
|
__optional_keys__ = frozenset({
|
||||||
"error",
|
"error",
|
||||||
@@ -62,6 +64,9 @@ class WorkflowState(dict):
|
|||||||
# node activate status
|
# node activate status
|
||||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||||
|
|
||||||
|
memory_storage_type: str
|
||||||
|
user_rag_memory_id: str
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStateManager:
|
class WorkflowStateManager:
|
||||||
def create_initial_state(
|
def create_initial_state(
|
||||||
@@ -85,7 +90,9 @@ class WorkflowStateManager:
|
|||||||
looping=0,
|
looping=0,
|
||||||
activate={
|
activate={
|
||||||
start_node_id: True
|
start_node_id: True
|
||||||
}
|
},
|
||||||
|
memory_storage_type=execution_context.memory_storage_type,
|
||||||
|
user_rag_memory_id=execution_context.user_rag_memory_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 15:11
|
# @Time : 2026/2/9 15:11
|
||||||
import re
|
import re
|
||||||
from queue import Queue
|
from collections import deque
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@@ -256,7 +256,7 @@ class StreamOutputCoordinator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
self.activate_end: str | None = None
|
self.activate_end: str | None = None
|
||||||
self.output_queue: Queue = Queue()
|
self.output_queue: deque[str] = deque()
|
||||||
self.processed_outputs = []
|
self.processed_outputs = []
|
||||||
|
|
||||||
def initialize_end_outputs(
|
def initialize_end_outputs(
|
||||||
@@ -266,7 +266,7 @@ class StreamOutputCoordinator:
|
|||||||
self.end_outputs = end_node_map
|
self.end_outputs = end_node_map
|
||||||
self.processed_outputs = []
|
self.processed_outputs = []
|
||||||
self.activate_end = None
|
self.activate_end = None
|
||||||
self.output_queue = Queue()
|
self.output_queue = deque()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_activate_end_info(self):
|
def current_activate_end_info(self):
|
||||||
@@ -296,13 +296,13 @@ class StreamOutputCoordinator:
|
|||||||
scope (str): The node ID or scope that has completed execution.
|
scope (str): The node ID or scope that has completed execution.
|
||||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||||
"""
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs:
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||||
self.output_queue.put(node)
|
self.output_queue.append(node)
|
||||||
self.processed_outputs.append(node)
|
self.processed_outputs.append(node)
|
||||||
if self.activate_end is None and not self.output_queue.empty():
|
if self.activate_end is None and self.output_queue:
|
||||||
self.activate_end = self.output_queue.get_nowait()
|
self.activate_end = self.output_queue.popleft()
|
||||||
|
|
||||||
async def emit_activate_chunk(
|
async def emit_activate_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -414,8 +414,8 @@ class StreamOutputCoordinator:
|
|||||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
if not self.output_queue.empty():
|
if self.output_queue:
|
||||||
self.activate_end = self.output_queue.get_nowait()
|
self.activate_end = self.output_queue.popleft()
|
||||||
# Move to next active End node if current one is done
|
# Move to next active End node if current one is done
|
||||||
if not self.activate_end and self.end_outputs:
|
if not self.activate_end and self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -373,6 +373,16 @@ class VariablePool:
|
|||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|
||||||
|
def is_file_variable(self, selector):
|
||||||
|
variable_struct = self.get_instance(selector, default=None, strict=False)
|
||||||
|
if variable_struct is None:
|
||||||
|
return False
|
||||||
|
if isinstance(variable_struct, FileVariable):
|
||||||
|
return True
|
||||||
|
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""导出为字典
|
"""导出为字典
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 13:51
|
# @Time : 2026/2/9 13:51
|
||||||
import datetime
|
import datetime
|
||||||
|
import time
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -82,13 +83,15 @@ class WorkflowExecutor:
|
|||||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||||
|
start_time = time.time()
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
self.workflow_config,
|
self.workflow_config,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.graph = builder.build()
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
self.variable_pool = builder.variable_pool
|
self.variable_pool = builder.variable_pool
|
||||||
self.graph = builder.build()
|
|
||||||
|
|
||||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||||
self.event_handler = EventStreamHandler(
|
self.event_handler = EventStreamHandler(
|
||||||
@@ -96,7 +99,8 @@ class WorkflowExecutor:
|
|||||||
variable_pool=self.variable_pool,
|
variable_pool=self.variable_pool,
|
||||||
execution_id=self.execution_context.execution_id
|
execution_id=self.execution_context.execution_id
|
||||||
)
|
)
|
||||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, "
|
||||||
|
f"cost: {time.time() - start_time:.4f}s")
|
||||||
|
|
||||||
return self.graph
|
return self.graph
|
||||||
|
|
||||||
@@ -134,94 +138,12 @@ class WorkflowExecutor:
|
|||||||
return event.get("data")
|
return event.get("data")
|
||||||
return self.result_builder.build_final_output(
|
return self.result_builder.build_final_output(
|
||||||
{"error": "Workflow execution did not end as expected"},
|
{"error": "Workflow execution did not end as expected"},
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
(datetime.datetime.now() - start).total_seconds(),
|
(datetime.datetime.now() - start).total_seconds(),
|
||||||
"",
|
"",
|
||||||
success=False
|
success=False
|
||||||
)
|
)
|
||||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
|
||||||
#
|
|
||||||
# start_time = datetime.datetime.now()
|
|
||||||
#
|
|
||||||
# # Execute the workflow
|
|
||||||
# try:
|
|
||||||
# # Build the workflow graph
|
|
||||||
# graph = self.build_graph()
|
|
||||||
#
|
|
||||||
# # Initialize the variable pool with input data
|
|
||||||
# await self.variable_initializer.initialize(
|
|
||||||
# variable_pool=self.variable_pool,
|
|
||||||
# input_data=input_data,
|
|
||||||
# execution_context=self.execution_context
|
|
||||||
# )
|
|
||||||
# initial_state = self.state_manager.create_initial_state(
|
|
||||||
# workflow_config=self.workflow_config,
|
|
||||||
# input_data=input_data,
|
|
||||||
# execution_context=self.execution_context,
|
|
||||||
# start_node_id=self.start_node_id
|
|
||||||
# )
|
|
||||||
#
|
|
||||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
|
||||||
#
|
|
||||||
# # Aggregate output from all End nodes
|
|
||||||
# full_content = ''
|
|
||||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
|
||||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
|
||||||
#
|
|
||||||
# # Append messages for user and assistant
|
|
||||||
# if input_data.get("files"):
|
|
||||||
# result["messages"].extend(
|
|
||||||
# [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("message", '')
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("files")
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": full_content
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
# else:
|
|
||||||
# result["messages"].extend(
|
|
||||||
# [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": input_data.get("message", '')
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": full_content
|
|
||||||
# }
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
# # Calculate elapsed time
|
|
||||||
# end_time = datetime.datetime.now()
|
|
||||||
# elapsed_time = (end_time - start_time).total_seconds()
|
|
||||||
#
|
|
||||||
# logger.info(
|
|
||||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
|
||||||
#
|
|
||||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# end_time = datetime.datetime.now()
|
|
||||||
# elapsed_time = (end_time - start_time).total_seconds()
|
|
||||||
#
|
|
||||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
|
||||||
# exc_info=True)
|
|
||||||
# return {
|
|
||||||
# "status": "failed",
|
|
||||||
# "error": str(e),
|
|
||||||
# "output": None,
|
|
||||||
# "node_outputs": {},
|
|
||||||
# "elapsed_time": elapsed_time,
|
|
||||||
# "token_usage": None
|
|
||||||
# }
|
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
@@ -255,7 +177,7 @@ class WorkflowExecutor:
|
|||||||
"data": {
|
"data": {
|
||||||
"execution_id": self.execution_context.execution_id,
|
"execution_id": self.execution_context.execution_id,
|
||||||
"workspace_id": self.execution_context.workspace_id,
|
"workspace_id": self.execution_context.workspace_id,
|
||||||
"conversation_id": input_data.get("conversation_id"),
|
"conversation_id": self.execution_context.conversation_id,
|
||||||
"timestamp": int(start_time.timestamp() * 1000)
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -376,6 +298,7 @@ class WorkflowExecutor:
|
|||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
result,
|
result,
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
elapsed_time,
|
elapsed_time,
|
||||||
full_content,
|
full_content,
|
||||||
@@ -396,6 +319,7 @@ class WorkflowExecutor:
|
|||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(
|
"data": self.result_builder.build_final_output(
|
||||||
result,
|
result,
|
||||||
|
self.execution_context,
|
||||||
self.variable_pool,
|
self.variable_pool,
|
||||||
elapsed_time,
|
elapsed_time,
|
||||||
full_content,
|
full_content,
|
||||||
@@ -409,7 +333,9 @@ async def execute_workflow(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Execute a workflow (convenience function, non-streaming).
|
Execute a workflow (convenience function, non-streaming).
|
||||||
@@ -420,6 +346,8 @@ async def execute_workflow(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: Workflow execution result.
|
dict: Workflow execution result.
|
||||||
@@ -427,7 +355,10 @@ async def execute_workflow(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
conversation_id=input_data.get("conversation_id"),
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
@@ -441,7 +372,9 @@ async def execute_workflow_stream(
|
|||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str,
|
||||||
|
memory_storage_type: str,
|
||||||
|
user_rag_memory_id: str
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Execute a workflow in streaming mode (convenience function).
|
Execute a workflow in streaming mode (convenience function).
|
||||||
@@ -452,6 +385,8 @@ async def execute_workflow_stream(
|
|||||||
execution_id (str): Execution ID.
|
execution_id (str): Execution ID.
|
||||||
workspace_id (str): Workspace ID.
|
workspace_id (str): Workspace ID.
|
||||||
user_id (str): User ID.
|
user_id (str): User ID.
|
||||||
|
user_rag_memory_id: rag knowledge db id
|
||||||
|
memory_storage_type: neo4j / rag
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||||
@@ -459,7 +394,10 @@ async def execute_workflow_stream(
|
|||||||
execution_context = ExecutionContext.create(
|
execution_context = ExecutionContext.create(
|
||||||
execution_id=execution_id,
|
execution_id=execution_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
memory_storage_type=memory_storage_type,
|
||||||
|
conversation_id=input_data.get("conversation_id"),
|
||||||
|
user_rag_memory_id=user_rag_memory_id
|
||||||
)
|
)
|
||||||
executor = WorkflowExecutor(
|
executor = WorkflowExecutor(
|
||||||
workflow_config=workflow_config,
|
workflow_config=workflow_config,
|
||||||
|
|||||||
@@ -65,8 +65,6 @@ class AgentNode(BaseNode):
|
|||||||
if not release:
|
if not release:
|
||||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
return release, message
|
return release, message
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AssignerNode(BaseNode):
|
class AssignerNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class BaseNode(ABC):
|
|||||||
All node types should inherit from this class and implement the `execute` method.
|
All node types should inherit from this class and implement the `execute` method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""Initialize the node.
|
"""Initialize the node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -41,6 +41,7 @@ class BaseNode(ABC):
|
|||||||
self.node_type = node_config["type"]
|
self.node_type = node_config["type"]
|
||||||
self.cycle = node_config.get("cycle")
|
self.cycle = node_config.get("cycle")
|
||||||
self.node_name = node_config.get("name", self.node_id)
|
self.node_name = node_config.get("name", self.node_id)
|
||||||
|
self.down_stream_nodes = down_stream_nodes
|
||||||
# 使用 or 运算符处理 None 值
|
# 使用 or 运算符处理 None 值
|
||||||
self.config = node_config.get("config") or {}
|
self.config = node_config.get("config") or {}
|
||||||
self.error_handling = node_config.get("error_handling") or {}
|
self.error_handling = node_config.get("error_handling") or {}
|
||||||
@@ -93,18 +94,16 @@ class BaseNode(ABC):
|
|||||||
dict: A dict with a single key 'activate', mapping node IDs to
|
dict: A dict with a single key 'activate', mapping node IDs to
|
||||||
their activation status (True/False).
|
their activation status (True/False).
|
||||||
"""
|
"""
|
||||||
edges = self.workflow_config.get("edges")
|
activate_flag = self.check_activate(state)
|
||||||
under_stream_nodes = [
|
|
||||||
edge.get("target")
|
if self.node_type not in BRANCH_NODES:
|
||||||
for edge in edges
|
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
|
||||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
else:
|
||||||
]
|
activate = {}
|
||||||
return {
|
|
||||||
"activate": {
|
activate[self.node_id] = activate_flag
|
||||||
node_id: self.check_activate(state)
|
|
||||||
for node_id in under_stream_nodes
|
return {"activate": activate}
|
||||||
} | {self.node_id: self.check_activate(state)}
|
|
||||||
}
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -315,8 +314,8 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
elapsed_time = (time.time() - start_time) * 1000
|
elapsed_time = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
logger.debug(f"Node {self.node_id} streaming execution finished, "
|
||||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||||
|
|
||||||
# Extract processed output (call subclass's _extract_output)
|
# Extract processed output (call subclass's _extract_output)
|
||||||
extracted_output = self._extract_output(final_result)
|
extracted_output = self._extract_output(final_result)
|
||||||
@@ -428,8 +427,8 @@ class BaseNode(ABC):
|
|||||||
when an error edge exists. If no error edge exists, this method
|
when an error edge exists. If no error edge exists, this method
|
||||||
raises an exception to stop the workflow.
|
raises an exception to stop the workflow.
|
||||||
"""
|
"""
|
||||||
# Check if the node has an error edge defined
|
# # Check if the node has an error edge defined
|
||||||
error_edge = self._find_error_edge()
|
# error_edge = self._find_error_edge()
|
||||||
|
|
||||||
# Extract input data (for logging or audit purposes)
|
# Extract input data (for logging or audit purposes)
|
||||||
input_data = self._extract_input(state, variable_pool)
|
input_data = self._extract_input(state, variable_pool)
|
||||||
@@ -447,27 +446,26 @@ class BaseNode(ABC):
|
|||||||
"error": error_message
|
"error": error_message
|
||||||
}
|
}
|
||||||
|
|
||||||
if error_edge:
|
# if error_edge:
|
||||||
# If an error edge exists, log a warning and continue to error node
|
# # If an error edge exists, log a warning and continue to error node
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||||
)
|
# )
|
||||||
return {
|
# return {
|
||||||
"node_outputs": {
|
# "node_outputs": {
|
||||||
self.node_id: node_output
|
# self.node_id: node_output
|
||||||
},
|
# },
|
||||||
"error": error_message,
|
# "error": error_message,
|
||||||
"error_node": self.node_id
|
# "error_node": self.node_id
|
||||||
}
|
# }
|
||||||
else:
|
# else:
|
||||||
# If no error edge, send the error via stream writer and stop the workflow
|
writer = get_stream_writer()
|
||||||
writer = get_stream_writer()
|
writer({
|
||||||
writer({
|
"type": "node_error",
|
||||||
"type": "node_error",
|
**node_output
|
||||||
**node_output
|
})
|
||||||
})
|
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""Extracts the input data for this node (used for logging or audit).
|
"""Extracts the input data for this node (used for logging or audit).
|
||||||
@@ -623,7 +621,6 @@ class BaseNode(ABC):
|
|||||||
async def process_message(
|
async def process_message(
|
||||||
api_config: ModelInfo,
|
api_config: ModelInfo,
|
||||||
content: str | dict | FileObject,
|
content: str | dict | FileObject,
|
||||||
end_user_id: str,
|
|
||||||
enable_file=False
|
enable_file=False
|
||||||
) -> list | str | None:
|
) -> list | str | None:
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
@@ -642,10 +639,10 @@ class BaseNode(ABC):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
elif isinstance(content, FileObject):
|
elif isinstance(content, FileObject):
|
||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(f"{provider}_{api_config.is_omni}"):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[f"{provider}_{api_config.is_omni}"]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
multimodal_service = MultimodalService(db, api_config=api_config)
|
||||||
file_obj = FileInput(
|
file_obj = FileInput(
|
||||||
type=content.type,
|
type=content.type,
|
||||||
url=content.url,
|
url=content.url,
|
||||||
@@ -654,16 +651,15 @@ class BaseNode(ABC):
|
|||||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||||
)
|
)
|
||||||
file_obj.set_content(content.get_content())
|
file_obj.set_content(content.get_content())
|
||||||
message = await multimodel_service.process_files(
|
message = await multimodal_service.process_files(
|
||||||
end_user_id,
|
|
||||||
[file_obj],
|
[file_obj],
|
||||||
)
|
)
|
||||||
content.set_content(file_obj.get_content())
|
content.set_content(file_obj.get_content())
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[f"{provider}_{api_config.is_omni}"] = message
|
||||||
return message
|
return message
|
||||||
return None
|
return None
|
||||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
raise TypeError(f'Unexpected input value type - {type(content)}')
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def process_model_output(content) -> str:
|
def process_model_output(content) -> str:
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ console.log(result)
|
|||||||
|
|
||||||
|
|
||||||
class CodeNode(BaseNode):
|
class CodeNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: CodeNodeConfig | None = None
|
self.typed_config: CodeNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -30,17 +30,13 @@ class CycleGraphNode(BaseNode):
|
|||||||
It acts as a container and execution controller for a subgraph.
|
It acts as a container and execution controller for a subgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
|
||||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
|
||||||
self.start_node_id = None # ID of the start node within the cycle
|
self.start_node_id = None # ID of the start node within the cycle
|
||||||
|
|
||||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||||
self.child_variable_pool: VariablePool | None = None
|
self.child_variable_pool: VariablePool | None = None
|
||||||
self.build_graph()
|
|
||||||
self.iteration_flag = True
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||||
@@ -119,11 +115,11 @@ class CycleGraphNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
remain_edges.append(edge)
|
remain_edges.append(edge)
|
||||||
|
|
||||||
# Update workflow_config by removing cycle nodes and internal edges
|
# # Update workflow_config by removing cycle nodes and internal edges
|
||||||
self.workflow_config["nodes"] = [
|
# self.workflow_config["nodes"] = [
|
||||||
node for node in nodes if node.get("cycle") != self.node_id
|
# node for node in nodes if node.get("cycle") != self.node_id
|
||||||
]
|
# ]
|
||||||
self.workflow_config["edges"] = remain_edges
|
# self.workflow_config["edges"] = remain_edges
|
||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
@@ -137,18 +133,18 @@ class CycleGraphNode(BaseNode):
|
|||||||
3. Compile the graph for runtime execution
|
3. Compile the graph for runtime execution
|
||||||
"""
|
"""
|
||||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
|
||||||
self.child_variable_pool = VariablePool()
|
self.child_variable_pool = VariablePool()
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
{
|
{
|
||||||
"nodes": self.cycle_nodes,
|
"nodes": self.cycle_nodes,
|
||||||
"edges": self.cycle_edges,
|
"edges": self.cycle_edges,
|
||||||
},
|
},
|
||||||
subgraph=True,
|
variable_pool=self.child_variable_pool,
|
||||||
variable_pool=self.child_variable_pool
|
cycle=self.node_id
|
||||||
)
|
)
|
||||||
self.start_node_id = builder.start_node_id
|
|
||||||
self.graph = builder.build()
|
self.graph = builder.build()
|
||||||
|
self.start_node_id = builder.start_node_id
|
||||||
self.child_variable_pool = builder.variable_pool
|
self.child_variable_pool = builder.variable_pool
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the node type is unsupported.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
raise RuntimeError("Unknown cycle node type")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
yield {
|
yield {
|
||||||
"__final__": True,
|
"__final__": True,
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from .config import DocExtractorNodeConfig
|
||||||
|
from .node import DocExtractorNode
|
||||||
|
|
||||||
|
__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"]
|
||||||
18
api/app/core/workflow/nodes/document_extractor/config.py
Normal file
18
api/app/core/workflow/nodes/document_extractor/config.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class DocExtractorNodeConfig(BaseNodeConfig):
|
||||||
|
file_selector: str = Field(
|
||||||
|
...,
|
||||||
|
description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"file_selector": "{{ sys.files }}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
103
api/app/core/workflow/nodes/document_extractor/node.py
Normal file
103
api/app/core/workflow/nodes/document_extractor/node.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
|
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
|
from app.db import get_db_read
|
||||||
|
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _file_object_to_file_input(f: FileObject) -> FileInput:
|
||||||
|
"""Convert workflow FileObject to multimodal FileInput."""
|
||||||
|
return FileInput(
|
||||||
|
type=FileType.DOCUMENT,
|
||||||
|
transfer_method=TransferMethod(f.transfer_method),
|
||||||
|
url=f.url or None,
|
||||||
|
upload_file_id=f.file_id or None,
|
||||||
|
file_type=f.origin_file_type or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalise_files(val: Any) -> list[FileObject]:
|
||||||
|
if isinstance(val, FileObject):
|
||||||
|
return [val]
|
||||||
|
if isinstance(val, dict) and val.get("is_file"):
|
||||||
|
return [FileObject(**val)]
|
||||||
|
if isinstance(val, list):
|
||||||
|
result: list[FileObject] = []
|
||||||
|
for item in val:
|
||||||
|
if isinstance(item, FileObject):
|
||||||
|
result.append(item)
|
||||||
|
elif isinstance(item, dict) and item.get("is_file"):
|
||||||
|
result.append(FileObject(**item))
|
||||||
|
else:
|
||||||
|
logger.warning("Ignoring non-file entry in file list for document extractor: %r", item)
|
||||||
|
return result
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class DocExtractorNode(BaseNode):
|
||||||
|
"""Document Extractor Node.
|
||||||
|
|
||||||
|
Reads one or more file variables and extracts their text content
|
||||||
|
by delegating to MultimodalService._extract_document_text.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
text (string) – full concatenated text of all input files
|
||||||
|
chunks (array[string]) – per-file extracted text
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
return {
|
||||||
|
"text": VariableType.STRING,
|
||||||
|
"chunks": VariableType.ARRAY_STRING,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _extract_output(self, business_result: Any) -> Any:
|
||||||
|
return business_result
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
return {"file_selector": self.config.get("file_selector")}
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
|
config = DocExtractorNodeConfig(**self.config)
|
||||||
|
|
||||||
|
raw_val = self.get_variable(config.file_selector, variable_pool, strict=False)
|
||||||
|
if raw_val is None:
|
||||||
|
logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty")
|
||||||
|
return {"text": "", "chunks": []}
|
||||||
|
|
||||||
|
files = _normalise_files(raw_val)
|
||||||
|
if not files:
|
||||||
|
return {"text": "", "chunks": []}
|
||||||
|
|
||||||
|
chunks: list[str] = []
|
||||||
|
with get_db_read() as db:
|
||||||
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
svc = MultimodalService(db)
|
||||||
|
for f in files:
|
||||||
|
try:
|
||||||
|
file_input = _file_object_to_file_input(f)
|
||||||
|
# Ensure URL is populated for local files
|
||||||
|
if not file_input.url:
|
||||||
|
file_input.url = await svc.get_file_url(file_input)
|
||||||
|
# Reuse cached bytes if already fetched
|
||||||
|
if f.get_content():
|
||||||
|
file_input.set_content(f.get_content())
|
||||||
|
text = await svc._extract_document_text(file_input)
|
||||||
|
chunks.append(text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
chunks.append("")
|
||||||
|
|
||||||
|
full_text = "\n\n".join(c for c in chunks if c)
|
||||||
|
logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}")
|
||||||
|
return {"text": full_text, "chunks": chunks}
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
"""End 节点配置"""
|
"""End 节点配置"""
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
|
||||||
|
|
||||||
|
|
||||||
class EndNodeConfig(BaseNodeConfig):
|
class EndNodeConfig(BaseNodeConfig):
|
||||||
|
|||||||
@@ -36,8 +36,6 @@ class EndNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
最终输出字符串
|
最终输出字符串
|
||||||
"""
|
"""
|
||||||
logger.info(f"节点 {self.node_id} (End) 开始执行")
|
|
||||||
|
|
||||||
# 获取配置的输出模板
|
# 获取配置的输出模板
|
||||||
output_template = self.config.get("output")
|
output_template = self.config.get("output")
|
||||||
|
|
||||||
@@ -46,11 +44,4 @@ class EndNode(BaseNode):
|
|||||||
output = self._render_template(output_template, variable_pool, strict=False)
|
output = self._render_template(output_template, variable_pool, strict=False)
|
||||||
else:
|
else:
|
||||||
output = ""
|
output = ""
|
||||||
|
|
||||||
# 统计信息(用于日志)
|
|
||||||
node_outputs = state.get("node_outputs", {})
|
|
||||||
total_nodes = len(node_outputs)
|
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -23,12 +23,13 @@ class NodeType(StrEnum):
|
|||||||
BREAK = "break"
|
BREAK = "break"
|
||||||
MEMORY_READ = "memory-read"
|
MEMORY_READ = "memory-read"
|
||||||
MEMORY_WRITE = "memory-write"
|
MEMORY_WRITE = "memory-write"
|
||||||
|
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||||
|
|
||||||
UNKNOWN = "unknown"
|
UNKNOWN = "unknown"
|
||||||
NOTES = "notes"
|
NOTES = "notes"
|
||||||
|
|
||||||
|
|
||||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||||
|
|
||||||
|
|
||||||
class ComparisonOperator(StrEnum):
|
class ComparisonOperator(StrEnum):
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class HttpErrorDefaultTamplete(BaseModel):
|
class HttpErrorDefaultTemplate(BaseModel):
|
||||||
body: str = Field(
|
body: str = Field(
|
||||||
default="",
|
default="",
|
||||||
description="Default body returned on HTTP error",
|
description="Default body returned on HTTP error",
|
||||||
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
|
|||||||
description="Error handling strategy: 'none', 'default', or 'branch'",
|
description="Error handling strategy: 'none', 'default', or 'branch'",
|
||||||
)
|
)
|
||||||
|
|
||||||
default: HttpErrorDefaultTamplete | None = Field(
|
default: HttpErrorDefaultTemplate | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Default response template for error handling",
|
description="Default response template for error handling",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
from app.core.workflow.utils.file_processor import mime_to_file_type
|
||||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.schemas import FileType, TransferMethod
|
from app.schemas import FileType, TransferMethod
|
||||||
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
|
|||||||
or a branch identifier string when error branching is enabled.
|
or a branch identifier string when error branching is enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class IfElseNode(BaseNode):
|
class IfElseNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: IfElseNodeConfig | None = None
|
self.typed_config: IfElseNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
@@ -21,9 +21,10 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(BaseNode):
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
return reranker
|
return reranker
|
||||||
|
|
||||||
|
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
kb_config.kb_id = child.id
|
||||||
|
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||||
|
return
|
||||||
|
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||||
|
match kb_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold))
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight))
|
||||||
|
case RetrieveType.HYBRID:
|
||||||
|
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight)
|
||||||
|
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold)
|
||||||
|
|
||||||
|
# Deduplicate hybrid retrieval results
|
||||||
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
|
if not unique_rs:
|
||||||
|
return
|
||||||
|
if self.typed_config.reranker_id:
|
||||||
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
|
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
|
else:
|
||||||
|
rs.extend(sorted(
|
||||||
|
unique_rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:kb_config.top_k])
|
||||||
|
case _:
|
||||||
|
raise RuntimeError("Unknown retrieval type")
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the knowledge retrieval workflow node.
|
Execute the knowledge retrieval workflow node.
|
||||||
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
query = self._render_template(self.typed_config.query, variable_pool)
|
query = self._render_template(self.typed_config.query, variable_pool)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
|
||||||
|
|
||||||
if not existing_ids:
|
|
||||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
|
||||||
|
|
||||||
rs = []
|
rs = []
|
||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
|
||||||
match kb_config.retrieve_type:
|
|
||||||
case RetrieveType.PARTICIPLE:
|
|
||||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold))
|
|
||||||
case RetrieveType.SEMANTIC:
|
|
||||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
|
||||||
case RetrieveType.HYBRID:
|
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight)
|
|
||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold)
|
|
||||||
|
|
||||||
# Deduplicate hy brid retrieval results
|
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
|
||||||
if not unique_rs:
|
|
||||||
continue
|
|
||||||
if self.typed_config.reranker_id:
|
|
||||||
vector_service.reranker = self.get_reranker_model()
|
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
|
||||||
else:
|
|
||||||
rs.extend(sorted(
|
|
||||||
unique_rs,
|
|
||||||
key=lambda d: d.metadata.get("score", 0),
|
|
||||||
reverse=True
|
|
||||||
)[:kb_config.top_k])
|
|
||||||
case _:
|
|
||||||
raise RuntimeError("Unknown retrieval type")
|
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
vector_service.reranker = self.get_reranker_model()
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
else:
|
else:
|
||||||
final_rs = sorted(
|
final_rs = sorted(
|
||||||
rs,
|
rs,
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
|
|||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: LLMNodeConfig | None = None
|
self.typed_config: LLMNodeConfig | None = None
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
@@ -144,7 +144,6 @@ class LLMNode(BaseNode):
|
|||||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||||
|
|
||||||
messages_config = self.typed_config.messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
if messages_config:
|
if messages_config:
|
||||||
# 使用 LangChain 消息格式
|
# 使用 LangChain 消息格式
|
||||||
messages = []
|
messages = []
|
||||||
@@ -153,7 +152,6 @@ class LLMNode(BaseNode):
|
|||||||
content_template = msg_config.content
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
content = self._render_template(content_template, variable_pool)
|
content = self._render_template(content_template, variable_pool)
|
||||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
@@ -161,32 +159,31 @@ class LLMNode(BaseNode):
|
|||||||
"content": await self.process_message(
|
"content": await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
content,
|
content,
|
||||||
user_id,
|
|
||||||
self.typed_config.vision,
|
self.typed_config.vision,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file.value, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -200,7 +197,7 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
content = await self.process_message(model_info, file, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
@@ -210,7 +207,6 @@ class LLMNode(BaseNode):
|
|||||||
message["content"] = await self.process_message(
|
message["content"] = await self.process_message(
|
||||||
model_info,
|
model_info,
|
||||||
message["content"],
|
message["content"],
|
||||||
user_id,
|
|
||||||
self.typed_config.vision
|
self.typed_config.vision
|
||||||
)
|
)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
@@ -5,14 +6,16 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
|
from app.schemas import FileInput
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryReadNodeConfig | None = None
|
self.typed_config: MemoryReadNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
@@ -36,19 +39,32 @@ class MemoryReadNode(BaseNode):
|
|||||||
search_switch=self.typed_config.search_switch,
|
search_switch=self.typed_config.search_switch,
|
||||||
history=[],
|
history=[],
|
||||||
db=db,
|
db=db,
|
||||||
storage_type="neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
user_rag_memory_id=""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MemoryWriteNode(BaseNode):
|
class MemoryWriteNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {"output": VariableType.STRING}
|
return {"output": VariableType.STRING}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
|
||||||
|
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
variable_pattern = re.compile(variable_pattern_string)
|
||||||
|
variables = variable_pattern.findall(content)
|
||||||
|
file_variables = []
|
||||||
|
for variable in variables:
|
||||||
|
if variable_pool.is_file_variable(variable):
|
||||||
|
file_variables.append(variable)
|
||||||
|
for var in file_variables:
|
||||||
|
content = content.replace(var, "")
|
||||||
|
return file_variables, content
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
self.typed_config = MemoryWriteNodeConfig(**self.config)
|
||||||
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
end_user_id = self.get_variable("sys.user_id", variable_pool)
|
||||||
@@ -63,17 +79,42 @@ class MemoryWriteNode(BaseNode):
|
|||||||
})
|
})
|
||||||
|
|
||||||
for message in self.typed_config.messages:
|
for message in self.typed_config.messages:
|
||||||
|
file_variables, content = self._extract_multimodal_memory_variables(
|
||||||
|
message.content,
|
||||||
|
variable_pool
|
||||||
|
)
|
||||||
|
file_info = []
|
||||||
|
for var in file_variables:
|
||||||
|
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
|
||||||
|
if isinstance(instence, FileVariable):
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=instence.value.type,
|
||||||
|
transfer_method=instence.value.transfer_method,
|
||||||
|
upload_file_id=instence.value.file_id,
|
||||||
|
url=instence.value.url,
|
||||||
|
file_type=instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
|
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
|
||||||
|
for file_instence in instence.value:
|
||||||
|
file_info.append(FileInput(
|
||||||
|
type=file_instence.value.type,
|
||||||
|
transfer_method=file_instence.value.transfer_method,
|
||||||
|
upload_file_id=file_instence.value.file_id,
|
||||||
|
url=file_instence.value.url,
|
||||||
|
file_type=file_instence.value.origin_file_type
|
||||||
|
).model_dump())
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": self._render_template(message.content, variable_pool)
|
"content": self._render_template(content, variable_pool),
|
||||||
|
"files": file_info
|
||||||
})
|
})
|
||||||
|
|
||||||
write_message_task.delay(
|
write_message_task.delay(
|
||||||
end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages,
|
message=messages,
|
||||||
str(self.typed_config.config_id),
|
config_id=str(self.typed_config.config_id),
|
||||||
"neo4j",
|
storage_type=state["memory_storage_type"],
|
||||||
""
|
user_rag_memory_id=state["user_rag_memory_id"]
|
||||||
)
|
)
|
||||||
|
|
||||||
return "success"
|
return "success"
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
|||||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||||
from app.core.workflow.nodes.breaker import BreakNode
|
from app.core.workflow.nodes.breaker import BreakNode
|
||||||
from app.core.workflow.nodes.tool import ToolNode
|
from app.core.workflow.nodes.tool import ToolNode
|
||||||
|
from app.core.workflow.nodes.document_extractor import DocExtractorNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -49,7 +50,8 @@ WorkflowNode = Union[
|
|||||||
ToolNode,
|
ToolNode,
|
||||||
MemoryReadNode,
|
MemoryReadNode,
|
||||||
MemoryWriteNode,
|
MemoryWriteNode,
|
||||||
CodeNode
|
CodeNode,
|
||||||
|
DocExtractorNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -81,6 +83,7 @@ class NodeFactory:
|
|||||||
NodeType.MEMORY_READ: MemoryReadNode,
|
NodeType.MEMORY_READ: MemoryReadNode,
|
||||||
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
NodeType.MEMORY_WRITE: MemoryWriteNode,
|
||||||
NodeType.CODE: CodeNode,
|
NodeType.CODE: CodeNode,
|
||||||
|
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -104,13 +107,15 @@ class NodeFactory:
|
|||||||
def create_node(
|
def create_node(
|
||||||
cls,
|
cls,
|
||||||
node_config: dict[str, Any],
|
node_config: dict[str, Any],
|
||||||
workflow_config: dict[str, Any]
|
workflow_config: dict[str, Any],
|
||||||
|
down_stream_nodes: list[str]
|
||||||
) -> WorkflowNode | None:
|
) -> WorkflowNode | None:
|
||||||
"""创建节点实例
|
"""创建节点实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_config: 节点配置
|
node_config: 节点配置
|
||||||
workflow_config: 工作流配置
|
workflow_config: 工作流配置
|
||||||
|
down_stream_nodes: 下游节点
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点实例或 None(对于不支持的节点类型)
|
节点实例或 None(对于不支持的节点类型)
|
||||||
@@ -127,7 +132,7 @@ class NodeFactory:
|
|||||||
|
|
||||||
# 创建节点实例
|
# 创建节点实例
|
||||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||||
return node_class(node_config, workflow_config)
|
return node_class(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_types(cls) -> list[str]:
|
def get_supported_types(cls) -> list[str]:
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNode(BaseNode):
|
class ParameterExtractorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
|||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|||||||
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
|
|||||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""初始化 Start 节点
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
Args:
|
|
||||||
node_config: 节点配置
|
|
||||||
workflow_config: 工作流配置
|
|
||||||
"""
|
|
||||||
super().__init__(node_config, workflow_config)
|
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config: StartNodeConfig | None = None
|
self.typed_config: StartNodeConfig | None = None
|
||||||
@@ -62,7 +56,6 @@ class StartNode(BaseNode):
|
|||||||
包含系统参数、会话变量和自定义变量的字典
|
包含系统参数、会话变量和自定义变量的字典
|
||||||
"""
|
"""
|
||||||
self.typed_config = StartNodeConfig(**self.config)
|
self.typed_config = StartNodeConfig(**self.config)
|
||||||
logger.info(f"节点 {self.node_id} (Start) 开始执行")
|
|
||||||
|
|
||||||
# 处理自定义变量(传入 pool 避免重复创建)
|
# 处理自定义变量(传入 pool 避免重复创建)
|
||||||
custom_vars = self._process_custom_variables(variable_pool)
|
custom_vars = self._process_custom_variables(variable_pool)
|
||||||
@@ -77,9 +70,9 @@ class StartNode(BaseNode):
|
|||||||
**custom_vars # 自定义变量作为节点输出的一部分
|
**custom_vars # 自定义变量作为节点输出的一部分
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"节点 {self.node_id} (Start) 执行完成,"
|
f"Node {self.node_id} (Start) execution completed, "
|
||||||
f"输出了 {len(custom_vars)} 个自定义变量"
|
f"outputting {len(custom_vars)} custom variables"
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
|||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ToolNodeConfig | None = None
|
self.typed_config: ToolNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VariableAggregatorNode(BaseNode):
|
class VariableAggregatorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -153,7 +153,8 @@ class TemplateRenderer:
|
|||||||
|
|
||||||
|
|
||||||
# 全局渲染器实例(严格模式)
|
# 全局渲染器实例(严格模式)
|
||||||
_default_renderer = TemplateRenderer(strict=True)
|
_strict_renderer = TemplateRenderer(strict=True)
|
||||||
|
_lenient_renderer = TemplateRenderer(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def render_template(
|
def render_template(
|
||||||
@@ -184,7 +185,7 @@ def render_template(
|
|||||||
... )
|
... )
|
||||||
'请分析: 这是一段文本'
|
'请分析: 这是一段文本'
|
||||||
"""
|
"""
|
||||||
renderer = TemplateRenderer(strict=strict)
|
renderer = _strict_renderer if strict else _lenient_renderer
|
||||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
错误列表
|
错误列表
|
||||||
"""
|
"""
|
||||||
return _default_renderer.validate(template)
|
return _strict_renderer.validate(template)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict, deque
|
||||||
from typing import Any, Union, TYPE_CHECKING
|
from typing import Any, Union, TYPE_CHECKING
|
||||||
|
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
@@ -119,7 +120,6 @@ class WorkflowValidator:
|
|||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
graphs = cls.get_subgraph(workflow_config)
|
graphs = cls.get_subgraph(workflow_config)
|
||||||
logger.info(graphs)
|
|
||||||
for index, graph in enumerate(graphs):
|
for index, graph in enumerate(graphs):
|
||||||
nodes = graph.get("nodes", [])
|
nodes = graph.get("nodes", [])
|
||||||
edges = graph.get("edges", [])
|
edges = graph.get("edges", [])
|
||||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
|||||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||||
if has_cycle:
|
if has_cycle:
|
||||||
errors.append(
|
errors.append(
|
||||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. 验证变量名
|
# 8. 验证变量名
|
||||||
@@ -204,18 +204,18 @@ class WorkflowValidator:
|
|||||||
Returns:
|
Returns:
|
||||||
可达节点 ID 集合
|
可达节点 ID 集合
|
||||||
"""
|
"""
|
||||||
|
adj = defaultdict(list)
|
||||||
|
for edge in edges:
|
||||||
|
adj[edge["source"]].append(edge["target"])
|
||||||
|
|
||||||
reachable = {start_id}
|
reachable = {start_id}
|
||||||
queue = [start_id]
|
queue = deque([start_id])
|
||||||
|
|
||||||
while queue:
|
while queue:
|
||||||
current = queue.pop(0)
|
current = queue.popleft()
|
||||||
for edge in edges:
|
for target in adj[current]:
|
||||||
if edge.get("source") == current:
|
if target not in reachable:
|
||||||
target = edge.get("target")
|
reachable.add(target)
|
||||||
if target and target not in reachable:
|
queue.append(target)
|
||||||
reachable.add(target)
|
|
||||||
queue.append(target)
|
|
||||||
|
|
||||||
return reachable
|
return reachable
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
|||||||
Returns:
|
Returns:
|
||||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||||
"""
|
"""
|
||||||
# 排除 loop 类型的节点
|
|
||||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
|
||||||
|
|
||||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
|
||||||
graph: dict[str, list[str]] = {}
|
graph: dict[str, list[str]] = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
|||||||
if edge_type == "error":
|
if edge_type == "error":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果涉及 loop 节点,跳过
|
|
||||||
if source in loop_nodes or target in loop_nodes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if source and target:
|
if source and target:
|
||||||
if source not in graph:
|
if source not in graph:
|
||||||
graph[source] = []
|
graph[source] = []
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
raise TypeError(f"Value must be a dict - {type(value)}:{value}")
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from app.repositories.neo4j.create_indexes import create_all_indexes
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, APIRouter
|
from fastapi import FastAPI, APIRouter
|
||||||
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||||
else:
|
else:
|
||||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||||
|
await create_all_indexes()
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
# 应用关闭事件
|
# 应用关闭事件
|
||||||
logger.info("应用程序正在关闭")
|
logger.info("应用程序正在关闭")
|
||||||
@@ -506,10 +509,13 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
404: "errors.common.not_found",
|
404: "errors.common.not_found",
|
||||||
405: "errors.common.method_not_allowed",
|
405: "errors.common.method_not_allowed",
|
||||||
409: "errors.common.conflict",
|
409: "errors.common.conflict",
|
||||||
|
413: "errors.common.payload_too_large",
|
||||||
422: "errors.common.validation_failed",
|
422: "errors.common.validation_failed",
|
||||||
429: "errors.common.too_many_requests",
|
429: "errors.common.too_many_requests",
|
||||||
500: "errors.common.internal_error",
|
500: "errors.common.internal_error",
|
||||||
|
502: "errors.common.bad_gateway",
|
||||||
503: "errors.common.service_unavailable",
|
503: "errors.common.service_unavailable",
|
||||||
|
504: "errors.common.gateway_timeout",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果有对应的翻译键,使用翻译
|
# 如果有对应的翻译键,使用翻译
|
||||||
@@ -534,7 +540,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
|||||||
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=exc.status_code,
|
status_code=exc.status_code,
|
||||||
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
|
content=fail(code=exc.status_code, msg=translated_message, error=exc.detail)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from .agent_app_config_model import AgentConfig
|
|||||||
from .app_release_model import AppRelease
|
from .app_release_model import AppRelease
|
||||||
from .memory_increment_model import MemoryIncrement
|
from .memory_increment_model import MemoryIncrement
|
||||||
from .end_user_model import EndUser
|
from .end_user_model import EndUser
|
||||||
|
from .end_user_info_model import EndUserInfo
|
||||||
from .appshare_model import AppShare
|
from .appshare_model import AppShare
|
||||||
from .release_share_model import ReleaseShare
|
from .release_share_model import ReleaseShare
|
||||||
from .conversation_model import Conversation, Message
|
from .conversation_model import Conversation, Message
|
||||||
@@ -60,6 +61,7 @@ __all__ = [
|
|||||||
"AppRelease",
|
"AppRelease",
|
||||||
"MemoryIncrement",
|
"MemoryIncrement",
|
||||||
"EndUser",
|
"EndUser",
|
||||||
|
"EndUserInfo",
|
||||||
"AppShare",
|
"AppShare",
|
||||||
"ReleaseShare",
|
"ReleaseShare",
|
||||||
"Conversation",
|
"Conversation",
|
||||||
|
|||||||
24
api/app/models/end_user_info_model.py
Normal file
24
api/app/models/end_user_info_model.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from sqlalchemy import Column, DateTime, ForeignKey, String, Text, ARRAY
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
|
class EndUserInfo(Base):
|
||||||
|
"""终端用户信息表 - 存储用户的别名和扩展信息"""
|
||||||
|
__tablename__ = "end_user_info"
|
||||||
|
|
||||||
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
|
||||||
|
end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), nullable=False, index=True, comment="关联的终端用户ID")
|
||||||
|
other_name = Column(String, nullable=False, comment="关联的用户名称")
|
||||||
|
aliases = Column(ARRAY(String), nullable=True, comment="用户别名列表(字符串数组)")
|
||||||
|
meta_data = Column(JSONB, nullable=True, comment="用户相关的扩展信息(JSON格式)")
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
|
|
||||||
|
# 与 EndUser 的关系
|
||||||
|
end_user = relationship("EndUser", back_populates="info")
|
||||||
@@ -22,6 +22,14 @@ class EndUser(Base):
|
|||||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
|
|
||||||
|
# 用户档案字段 - User Profile Fields
|
||||||
|
position = Column(String, nullable=True, comment="职位")
|
||||||
|
department = Column(String, nullable=True, comment="部门")
|
||||||
|
contact = Column(String, nullable=True, comment="联系方式")
|
||||||
|
phone = Column(String, nullable=True, comment="电话")
|
||||||
|
hire_date = Column(DateTime, nullable=True, comment="入职日期")
|
||||||
|
updatetime_profile = Column(DateTime, nullable=True, comment="核心档案信息最后更新时间")
|
||||||
|
|
||||||
memory_config_id = Column(
|
memory_config_id = Column(
|
||||||
UUID(as_uuid=True),
|
UUID(as_uuid=True),
|
||||||
ForeignKey("memory_config.config_id"),
|
ForeignKey("memory_config.config_id"),
|
||||||
@@ -30,14 +38,6 @@ class EndUser(Base):
|
|||||||
comment="关联的记忆配置ID"
|
comment="关联的记忆配置ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 用户基本信息字段
|
|
||||||
position = Column(String, nullable=True, comment="职位")
|
|
||||||
department = Column(String, nullable=True, comment="部门")
|
|
||||||
contact = Column(String, nullable=True, comment="联系方式")
|
|
||||||
phone = Column(String, nullable=True, comment="电话")
|
|
||||||
hire_date = Column(DateTime, nullable=True, comment="入职日期")
|
|
||||||
updatetime_profile = Column(DateTime, nullable=True, comment="核心档案信息最后更新时间")
|
|
||||||
|
|
||||||
# 用户摘要四个维度 - User Summary Four Dimensions
|
# 用户摘要四个维度 - User Summary Four Dimensions
|
||||||
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
|
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
|
||||||
personality_traits = Column(Text, nullable=True, comment="性格特点")
|
personality_traits = Column(Text, nullable=True, comment="性格特点")
|
||||||
@@ -66,3 +66,6 @@ class EndUser(Base):
|
|||||||
|
|
||||||
# 与 WorkSpace 的反向关系
|
# 与 WorkSpace 的反向关系
|
||||||
workspace = relationship("Workspace", back_populates="end_users")
|
workspace = relationship("Workspace", back_populates="end_users")
|
||||||
|
|
||||||
|
# 与 EndUserInfo 的反向关系
|
||||||
|
info = relationship("EndUserInfo", back_populates="end_user", cascade="all, delete-orphan")
|
||||||
@@ -30,6 +30,9 @@ class MemoryConfig(Base):
|
|||||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||||
|
vision_id = Column(String, nullable=True, comment="视觉模型配置ID")
|
||||||
|
audio_id = Column(String, nullable=True, comment="语音模型配置ID")
|
||||||
|
video_id = Column(String, nullable=True, comment="视频模型配置ID")
|
||||||
|
|
||||||
# 记忆萃取引擎配置
|
# 记忆萃取引擎配置
|
||||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ import datetime
|
|||||||
import uuid
|
import uuid
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
|
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, UniqueConstraint, Integer, Table, text
|
||||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
from sqlalchemy.dialects.postgresql import UUID, JSON, ARRAY
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
|
|
||||||
@@ -26,9 +27,9 @@ class ModelType(StrEnum):
|
|||||||
RERANK = "rerank"
|
RERANK = "rerank"
|
||||||
# TTS = "tts"
|
# TTS = "tts"
|
||||||
# SPEECH2TEXT = "speech2text"
|
# SPEECH2TEXT = "speech2text"
|
||||||
# IMAGE = "image"
|
IMAGE = "image"
|
||||||
# AUDIO = "audio"
|
# AUDIO = "audio"
|
||||||
# VISION = "vision"
|
VIDEO = "video"
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider(StrEnum):
|
class ModelProvider(StrEnum):
|
||||||
@@ -45,6 +46,7 @@ class ModelProvider(StrEnum):
|
|||||||
XINFERENCE = "xinference"
|
XINFERENCE = "xinference"
|
||||||
GPUSTACK = "gpustack"
|
GPUSTACK = "gpustack"
|
||||||
BEDROCK = "bedrock"
|
BEDROCK = "bedrock"
|
||||||
|
VOLCANO = "volcano"
|
||||||
COMPOSITE = "composite"
|
COMPOSITE = "composite"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user