Compare commits
263 Commits
v0.2.5-hot
...
v0.2.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b9ebe22df1 | ||
|
|
6c8dca6379 | ||
|
|
819d205166 | ||
|
|
9e17f65eda | ||
|
|
7373f68172 | ||
|
|
0999bd30d7 | ||
|
|
f01185a7fc | ||
|
|
7cd7303754 | ||
|
|
d19fec2155 | ||
|
|
df18868888 | ||
|
|
4438b08560 | ||
|
|
1029f94669 | ||
|
|
0a3acf446d | ||
|
|
5a7723553c | ||
|
|
975844eccf | ||
|
|
865ad31f2f | ||
|
|
b756f0c86c | ||
|
|
3e5f6176af | ||
|
|
ab5b165dc2 | ||
|
|
f9393c2f63 | ||
|
|
aa6638424c | ||
|
|
834387e254 | ||
|
|
9caa986c80 | ||
|
|
72b84dfc8f | ||
|
|
af10195025 | ||
|
|
22382423ad | ||
|
|
0f80c67cbd | ||
|
|
aa6473c1c7 | ||
|
|
cde61cb6ac | ||
|
|
b1368997c2 | ||
|
|
ec7dc448c1 | ||
|
|
254147265e | ||
|
|
479bba9a4e | ||
|
|
cfb39a6baa | ||
|
|
05c9ed1450 | ||
|
|
f53633a8b8 | ||
|
|
63882e9391 | ||
|
|
3c4dfb868f | ||
|
|
cae9105b8d | ||
|
|
2c9401ccfb | ||
|
|
2b0dedc81c | ||
|
|
16b87de0df | ||
|
|
8c3af7f4ff | ||
|
|
5f56cc8056 | ||
|
|
827ab27bef | ||
|
|
ccc67df8df | ||
|
|
82538c469f | ||
|
|
076ceee29d | ||
|
|
822b73b015 | ||
|
|
862bff51cb | ||
|
|
bccbeaabe4 | ||
|
|
03676b7adc | ||
|
|
af6fde414f | ||
|
|
d069809001 | ||
|
|
fc240849cf | ||
|
|
61d2a328fe | ||
|
|
fed0ae8e9c | ||
|
|
eaf0de453b | ||
|
|
e833db954a | ||
|
|
0b2651f4ed | ||
|
|
10c677a6fd | ||
|
|
3398c4737a | ||
|
|
a008f5fbef | ||
|
|
6a42e73667 | ||
|
|
7611db19f3 | ||
|
|
d3399dfaf5 | ||
|
|
248f0d95ac | ||
|
|
5c39d841ee | ||
|
|
87be67cb9a | ||
|
|
1a08bea864 | ||
|
|
bc4406cec6 | ||
|
|
4206c849c3 | ||
|
|
3f052b7798 | ||
|
|
f1c5f24f6b | ||
|
|
e981c95225 | ||
|
|
4ce4f53835 | ||
|
|
f16e369540 | ||
|
|
47bf93d65e | ||
|
|
5c2e0af33e | ||
|
|
aaa0410781 | ||
|
|
366b148f3d | ||
|
|
6a265de31c | ||
|
|
c3707f543c | ||
|
|
8de368348b | ||
|
|
d052c31ac5 | ||
|
|
31320afed6 | ||
|
|
7afe507296 | ||
|
|
4188443101 | ||
|
|
a1fc0fd394 | ||
|
|
71fe35533d | ||
|
|
a2ed335e59 | ||
|
|
8422a05d74 | ||
|
|
139ae3bcb4 | ||
|
|
a0a57d5fbb | ||
|
|
80fa88ac37 | ||
|
|
0fda1c752d | ||
|
|
6c2fc75199 | ||
|
|
2cb6aeb022 | ||
|
|
e0174f75b3 | ||
|
|
51d04746a3 | ||
|
|
3b08d6c320 | ||
|
|
495c5802a0 | ||
|
|
621b074b3d | ||
|
|
6df32983b5 | ||
|
|
9c9fe9dde7 | ||
|
|
128c1a6178 | ||
|
|
f90e102854 | ||
|
|
2e1eb9a5a6 | ||
|
|
60a95f6556 | ||
|
|
218637e81d | ||
|
|
404f78af0f | ||
|
|
6301528301 | ||
|
|
6feea968e0 | ||
|
|
b5199b2eb9 | ||
|
|
78ce2a9a8b | ||
|
|
6ed542b007 | ||
|
|
5322b0c4a3 | ||
|
|
a72d5d2c77 | ||
|
|
16c1cbe24f | ||
|
|
0d8f4c76e7 | ||
|
|
e511b14933 | ||
|
|
b5ba53208e | ||
|
|
b8bfb4d0c5 | ||
|
|
1b666638bc | ||
|
|
2bd364eca3 | ||
|
|
f27fc51801 | ||
|
|
0f85eff76b | ||
|
|
0def474cc2 | ||
|
|
590ec3a446 | ||
|
|
23bfdcefef | ||
|
|
647a978865 | ||
|
|
86f72100f0 | ||
|
|
8b255259ba | ||
|
|
8aad8faae9 | ||
|
|
420f391f3c | ||
|
|
817221347f | ||
|
|
13dce5e265 | ||
|
|
850d9ee70b | ||
|
|
ba36ccb21f | ||
|
|
f712754927 | ||
|
|
efe3865aa4 | ||
|
|
53dbe2f436 | ||
|
|
720498084b | ||
|
|
f5eda38dc9 | ||
|
|
8ada221777 | ||
|
|
4ee198813a | ||
|
|
440e8acd99 | ||
|
|
37325e9802 | ||
|
|
778bc4bd70 | ||
|
|
f78f59ec42 | ||
|
|
d4c4160215 | ||
|
|
85aea97c21 | ||
|
|
b075cad4de | ||
|
|
f326febc8a | ||
|
|
1738e45090 | ||
|
|
6e758faa37 | ||
|
|
32e79c5df0 | ||
|
|
da4a1f536d | ||
|
|
b3af757167 | ||
|
|
82794f051a | ||
|
|
c041d24989 | ||
|
|
1d662fb63e | ||
|
|
d1933d2aef | ||
|
|
163872be6e | ||
|
|
14fcb66a9c | ||
|
|
c488eb0cd0 | ||
|
|
91d20f7272 | ||
|
|
c3d7963fe0 | ||
|
|
c31a92bf01 | ||
|
|
b5703c1b82 | ||
|
|
df34735a9b | ||
|
|
31bee889d7 | ||
|
|
b3ba0a6ed6 | ||
|
|
ce3b7897d7 | ||
|
|
9115ad6950 | ||
|
|
c6b76438f4 | ||
|
|
68c4c7429c | ||
|
|
8466c8e019 | ||
|
|
d899b27448 | ||
|
|
66c153f1ad | ||
|
|
c6c7a1827c | ||
|
|
8fdaebbe6e | ||
|
|
9a98ccff2c | ||
|
|
ee4027c561 | ||
|
|
7f36a06f26 | ||
|
|
0826a34d8b | ||
|
|
1792cb4d93 | ||
|
|
304ccef101 | ||
|
|
bdc22c892d | ||
|
|
a5034e84ba | ||
|
|
6e2de96fed | ||
|
|
2b6d86e591 | ||
|
|
8c6f4cb117 | ||
|
|
16d4b32eb7 | ||
|
|
45a64dbbac | ||
|
|
537668b463 | ||
|
|
07fea23dd0 | ||
|
|
cef14291f0 | ||
|
|
bbde0588af | ||
|
|
aa7d52568b | ||
|
|
f39c77ac70 | ||
|
|
aa733354e8 | ||
|
|
7cec966979 | ||
|
|
74865d2cf2 | ||
|
|
c9a8753473 | ||
|
|
ce8a2cbe34 | ||
|
|
c0fdd0c6d3 | ||
|
|
88bfcfe6cd | ||
|
|
c4dcf1fd65 | ||
|
|
6cebddf893 | ||
|
|
1738ed3664 | ||
|
|
37ddcb91ac | ||
|
|
574ab4506b | ||
|
|
81353538e5 | ||
|
|
5abfcdfbe8 | ||
|
|
9962a61c21 | ||
|
|
5cf2b08777 | ||
|
|
9be1c01b70 | ||
|
|
62b2ecdfc2 | ||
|
|
2ff9000d25 | ||
|
|
5829148ce4 | ||
|
|
8e15a340f6 | ||
|
|
1270b7cdd8 | ||
|
|
7c02fe8148 | ||
|
|
4ac63e1c23 | ||
|
|
4aeb653ed2 | ||
|
|
2d5c2de613 | ||
|
|
96590941cf | ||
|
|
0655ff4a91 | ||
|
|
0ba370052e | ||
|
|
4d59e04aba | ||
|
|
6db6c33564 | ||
|
|
ed0d963aec | ||
|
|
3a36d038ee | ||
|
|
3d068a9c96 | ||
|
|
87df352adc | ||
|
|
8b546b7366 | ||
|
|
77ea0680fb | ||
|
|
4c592bf7e3 | ||
|
|
6718553bf4 | ||
|
|
79dc6f3f69 | ||
|
|
8df72d2822 | ||
|
|
3ce5926689 | ||
|
|
035464c0ac | ||
|
|
f1fcffbfc0 | ||
|
|
b79fe07052 | ||
|
|
e6aa0e0e10 | ||
|
|
54700e6fbe | ||
|
|
3a0671c661 | ||
|
|
1037729fb3 | ||
|
|
5f211620c5 | ||
|
|
cb6a3aae9e | ||
|
|
5e512df3d4 | ||
|
|
9916cf3265 | ||
|
|
f7aed9dd98 | ||
|
|
5253cf3899 | ||
|
|
f7d92be5ea | ||
|
|
97d8168824 | ||
|
|
550bd4da23 | ||
|
|
4f0b653a82 | ||
|
|
616709acbb | ||
|
|
67053ab8ae | ||
|
|
33238d34c9 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,6 +29,7 @@ search_results.json
|
||||
api/migrations/versions
|
||||
tmp
|
||||
files
|
||||
powers/
|
||||
|
||||
# Exclude dep files
|
||||
huggingface.co/
|
||||
|
||||
@@ -226,8 +226,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (Using Redis as broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT Secret Key (Formation method: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -201,8 +201,8 @@ REDIS_PORT=6379
|
||||
REDIS_DB=1
|
||||
|
||||
# Celery (使用Redis作为broker)
|
||||
BROKER_URL=redis://127.0.0.1:6379/0
|
||||
RESULT_BACKEND=redis://127.0.0.1:6379/0
|
||||
REDIS_DB_CELERY_BROKER=1
|
||||
REDIS_DB_CELERY_BACKEND=2
|
||||
|
||||
# JWT密钥 (生成方式: openssl rand -hex 32)
|
||||
SECRET_KEY=your-secret-key-here
|
||||
|
||||
@@ -10,7 +10,6 @@ from app.core.config import settings
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 创建连接池
|
||||
pool = ConnectionPool.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
|
||||
@@ -21,6 +20,7 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
try:
|
||||
@@ -29,7 +29,8 @@ async def get_redis_connection():
|
||||
logger.error(f"Redis连接失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
|
||||
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
"""设置Redis键值
|
||||
|
||||
Args:
|
||||
@@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
try:
|
||||
if isinstance(val, dict):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
@@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None):
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
|
||||
async def aio_redis_get(key: str):
|
||||
"""获取Redis键值"""
|
||||
try:
|
||||
@@ -58,6 +60,7 @@ async def aio_redis_get(key: str):
|
||||
logger.error(f"Redis get错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_delete(key: str):
|
||||
"""删除Redis键"""
|
||||
try:
|
||||
@@ -66,6 +69,7 @@ async def aio_redis_delete(key: str):
|
||||
logger.error(f"Redis delete错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
"""发布消息到Redis频道"""
|
||||
try:
|
||||
@@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
|
||||
logger.error(f"Redis发布错误: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
class RedisSubscriber:
|
||||
"""Redis订阅器"""
|
||||
|
||||
|
||||
def __init__(self, channel: str):
|
||||
self.channel = channel
|
||||
self.conn = None
|
||||
@@ -88,25 +93,25 @@ class RedisSubscriber:
|
||||
self.is_closed = False
|
||||
self._queue = asyncio.Queue()
|
||||
self._task = None
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""开始订阅"""
|
||||
if self.is_closed or self._task:
|
||||
return
|
||||
|
||||
|
||||
self._task = asyncio.create_task(self._receive_messages())
|
||||
logger.info(f"开始订阅: {self.channel}")
|
||||
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""接收消息"""
|
||||
try:
|
||||
self.conn = await get_redis_connection()
|
||||
if not self.conn:
|
||||
return
|
||||
|
||||
|
||||
self.pubsub = self.conn.pubsub()
|
||||
await self.pubsub.subscribe(self.channel)
|
||||
|
||||
|
||||
while not self.is_closed:
|
||||
try:
|
||||
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
|
||||
@@ -127,7 +132,7 @@ class RedisSubscriber:
|
||||
finally:
|
||||
await self._queue.put(None)
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
async def _cleanup(self):
|
||||
"""清理资源"""
|
||||
if self.pubsub:
|
||||
@@ -141,7 +146,7 @@ class RedisSubscriber:
|
||||
await self.conn.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def get_message(self) -> Optional[Dict[str, Any]]:
|
||||
"""获取消息"""
|
||||
if self.is_closed:
|
||||
@@ -153,7 +158,7 @@ class RedisSubscriber:
|
||||
except Exception as e:
|
||||
logger.error(f"获取消息错误: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""关闭订阅器"""
|
||||
if self.is_closed:
|
||||
@@ -163,32 +168,33 @@ class RedisSubscriber:
|
||||
self._task.cancel()
|
||||
await self._cleanup()
|
||||
|
||||
|
||||
class RedisPubSubManager:
|
||||
"""Redis发布订阅管理器"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.subscribers = {}
|
||||
|
||||
|
||||
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
|
||||
return await aio_redis_publish(channel, message)
|
||||
|
||||
|
||||
def get_subscriber(self, channel: str) -> RedisSubscriber:
|
||||
if channel in self.subscribers:
|
||||
subscriber = self.subscribers[channel]
|
||||
if not subscriber.is_closed:
|
||||
return subscriber
|
||||
|
||||
|
||||
subscriber = RedisSubscriber(channel)
|
||||
self.subscribers[channel] = subscriber
|
||||
return subscriber
|
||||
|
||||
|
||||
def cancel_subscription(self, channel: str) -> bool:
|
||||
if channel in self.subscribers:
|
||||
asyncio.create_task(self.subscribers[channel].close())
|
||||
del self.subscribers[channel]
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def cancel_all_subscriptions(self) -> int:
|
||||
count = len(self.subscribers)
|
||||
for subscriber in self.subscribers.values():
|
||||
@@ -196,6 +202,6 @@ class RedisPubSubManager:
|
||||
self.subscribers.clear()
|
||||
return count
|
||||
|
||||
|
||||
# 全局实例
|
||||
pubsub_manager = RedisPubSubManager()
|
||||
|
||||
|
||||
6
api/app/cache/__init__.py
vendored
6
api/app/cache/__init__.py
vendored
@@ -2,7 +2,9 @@
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
__all__ = []
|
||||
__all__ = [
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -2,7 +2,11 @@
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .interest_memory import InterestMemoryCache
|
||||
from .activity_stats_cache import ActivityStatsCache
|
||||
|
||||
__all__ = []
|
||||
__all__ = [
|
||||
"InterestMemoryCache",
|
||||
"ActivityStatsCache",
|
||||
]
|
||||
|
||||
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
124
api/app/cache/memory/activity_stats_cache.py
vendored
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Recent Activity Stats Cache
|
||||
|
||||
记忆提取活动统计缓存模块
|
||||
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放
|
||||
查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
ACTIVITY_STATS_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class ActivityStatsCache:
|
||||
"""记忆提取活动统计缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:activity_stats"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, workspace_id: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
|
||||
|
||||
@classmethod
|
||||
async def set_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
stats: Dict[str, Any],
|
||||
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
stats: 统计数据,格式:
|
||||
{
|
||||
"chunk_count": int,
|
||||
"statements_count": int,
|
||||
"triplet_entities_count": int,
|
||||
"triplet_relations_count": int,
|
||||
"temporal_count": int,
|
||||
}
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
payload = {
|
||||
"stats": stats,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"workspace_id": workspace_id,
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""获取记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
统计数据字典,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
return payload
|
||||
logger.info(f"活动统计缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_activity_stats(
|
||||
cls,
|
||||
workspace_id: str,
|
||||
) -> bool:
|
||||
"""删除记忆提取活动统计缓存
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -7,20 +7,48 @@ from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
|
||||
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
|
||||
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
|
||||
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
|
||||
|
||||
# Build canonical broker/backend URLs and force them into os.environ so that
|
||||
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
|
||||
# cannot be overridden by stray env vars.
|
||||
# See: https://github.com/celery/celery/issues/4284
|
||||
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
|
||||
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
|
||||
os.environ["CELERY_BROKER_URL"] = _broker_url
|
||||
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
|
||||
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
|
||||
# integration and accidentally override our canonical URLs.
|
||||
os.environ.pop("BROKER_URL", None)
|
||||
os.environ.pop("RESULT_BACKEND", None)
|
||||
os.environ.pop("CELERY_BROKER", None)
|
||||
os.environ.pop("CELERY_BACKEND", None)
|
||||
|
||||
celery_app = Celery(
|
||||
"redbear_tasks",
|
||||
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
broker=_broker_url,
|
||||
backend=_backend_url,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Celery app initialized",
|
||||
extra={
|
||||
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
|
||||
},
|
||||
)
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
@@ -44,8 +72,8 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
@@ -92,7 +120,7 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
|
||||
1
api/app/config/__init__.py
Normal file
1
api/app/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Configuration module for application settings."""
|
||||
239
api/app/config/default_ontology_config.py
Normal file
239
api/app/config/default_ontology_config.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""默认本体场景配置
|
||||
|
||||
本模块定义系统预设的本体场景和实体类型配置。
|
||||
这些配置用于在工作空间创建时自动初始化默认场景。
|
||||
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
|
||||
"""
|
||||
|
||||
# 在线教育场景配置
|
||||
ONLINE_EDUCATION_SCENE = {
|
||||
"name_chinese": "在线教育",
|
||||
"name_english": "Online Education",
|
||||
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
|
||||
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "学生",
|
||||
"name_english": "Student",
|
||||
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
|
||||
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教师",
|
||||
"name_english": "Teacher",
|
||||
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
|
||||
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课程",
|
||||
"name_english": "Course",
|
||||
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
|
||||
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
|
||||
},
|
||||
{
|
||||
"name_chinese": "作业",
|
||||
"name_english": "Assignment",
|
||||
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
|
||||
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成绩",
|
||||
"name_english": "Grade",
|
||||
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
|
||||
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "考试",
|
||||
"name_english": "Exam",
|
||||
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
|
||||
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教室",
|
||||
"name_english": "Classroom",
|
||||
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
|
||||
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学科",
|
||||
"name_english": "Subject",
|
||||
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
|
||||
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教材",
|
||||
"name_english": "Textbook",
|
||||
"description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性",
|
||||
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
|
||||
},
|
||||
{
|
||||
"name_chinese": "班级",
|
||||
"name_english": "Class",
|
||||
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
|
||||
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
|
||||
},
|
||||
{
|
||||
"name_chinese": "学期",
|
||||
"name_english": "Semester",
|
||||
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
|
||||
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "课时",
|
||||
"name_english": "Class Hour",
|
||||
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
|
||||
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
|
||||
},
|
||||
{
|
||||
"name_chinese": "教学计划",
|
||||
"name_english": "Teaching Plan",
|
||||
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
|
||||
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 情感陪伴场景配置
|
||||
EMOTIONAL_COMPANION_SCENE = {
|
||||
"name_chinese": "情感陪伴",
|
||||
"name_english": "Emotional Companion",
|
||||
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
|
||||
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
|
||||
"types": [
|
||||
{
|
||||
"name_chinese": "用户",
|
||||
"name_english": "User",
|
||||
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
|
||||
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
|
||||
},
|
||||
{
|
||||
"name_chinese": "情绪",
|
||||
"name_english": "Emotion",
|
||||
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
|
||||
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
|
||||
},
|
||||
{
|
||||
"name_chinese": "活动",
|
||||
"name_english": "Activity",
|
||||
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
|
||||
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
|
||||
},
|
||||
{
|
||||
"name_chinese": "对话",
|
||||
"name_english": "Conversation",
|
||||
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
|
||||
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
|
||||
},
|
||||
{
|
||||
"name_chinese": "兴趣爱好",
|
||||
"name_english": "Hobby",
|
||||
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
|
||||
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "日常事件",
|
||||
"name_english": "Daily Event",
|
||||
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
|
||||
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "关系",
|
||||
"name_english": "Relationship",
|
||||
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
|
||||
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
|
||||
},
|
||||
{
|
||||
"name_chinese": "回忆",
|
||||
"name_english": "Memory",
|
||||
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
|
||||
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
|
||||
},
|
||||
{
|
||||
"name_chinese": "地点",
|
||||
"name_english": "Location",
|
||||
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
|
||||
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
|
||||
},
|
||||
{
|
||||
"name_chinese": "时间节点",
|
||||
"name_english": "Time Point",
|
||||
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
|
||||
"description_english": "Important time markers, including attributes such as date, event, and significance"
|
||||
},
|
||||
{
|
||||
"name_chinese": "目标",
|
||||
"name_english": "Goal",
|
||||
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
|
||||
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
|
||||
},
|
||||
{
|
||||
"name_chinese": "成就",
|
||||
"name_english": "Achievement",
|
||||
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
|
||||
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# 导出默认场景列表
|
||||
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
|
||||
|
||||
|
||||
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景名称(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景名称
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("name_english", scene_config.get("name_chinese"))
|
||||
return scene_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
|
||||
"""获取场景描述(根据语言)
|
||||
|
||||
Args:
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的场景描述
|
||||
"""
|
||||
if language == "en":
|
||||
return scene_config.get("description_english", scene_config.get("description_chinese"))
|
||||
return scene_config.get("description_chinese")
|
||||
|
||||
|
||||
def get_type_name(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型名称(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型名称
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("name_english", type_config.get("name_chinese"))
|
||||
return type_config.get("name_chinese")
|
||||
|
||||
|
||||
def get_type_description(type_config: dict, language: str = "zh") -> str:
|
||||
"""获取类型描述(根据语言)
|
||||
|
||||
Args:
|
||||
type_config: 类型配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
对应语言的类型描述
|
||||
"""
|
||||
if language == "en":
|
||||
return type_config.get("description_english", type_config.get("description_chinese"))
|
||||
return type_config.get("description_chinese")
|
||||
249
api/app/config/default_ontology_initializer.py
Normal file
249
api/app/config/default_ontology_initializer.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""默认本体场景初始化器
|
||||
|
||||
本模块提供默认本体场景和类型的自动初始化功能。
|
||||
在工作空间创建时,自动添加预设的本体场景和实体类型。
|
||||
|
||||
Classes:
|
||||
DefaultOntologyInitializer: 默认本体场景初始化器
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config.default_ontology_config import (
|
||||
DEFAULT_SCENES,
|
||||
get_scene_name,
|
||||
get_scene_description,
|
||||
get_type_name,
|
||||
get_type_description,
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
class DefaultOntologyInitializer:
|
||||
"""默认本体场景初始化器
|
||||
|
||||
负责在工作空间创建时自动初始化默认的本体场景和类型。
|
||||
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
|
||||
|
||||
Attributes:
|
||||
db: 数据库会话
|
||||
scene_repo: 场景Repository
|
||||
class_repo: 类型Repository
|
||||
logger: 业务日志记录器
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.scene_repo = OntologySceneRepository(db)
|
||||
self.class_repo = OntologyClassRepository(db)
|
||||
self.logger = get_business_logger()
|
||||
|
||||
def initialize_default_scenes(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
language: str = "zh"
|
||||
) -> Tuple[bool, str]:
|
||||
"""为工作空间初始化默认场景
|
||||
|
||||
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
|
||||
如果创建失败,记录错误日志但不抛出异常。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
language: 语言类型 ("zh" 或 "en"),默认为 "zh"
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
self.logger.info(
|
||||
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
scenes_created = 0
|
||||
total_types_created = 0
|
||||
|
||||
# 遍历默认场景配置
|
||||
for scene_config in DEFAULT_SCENES:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
|
||||
# 创建场景及其类型
|
||||
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
|
||||
|
||||
if scene_id:
|
||||
scenes_created += 1
|
||||
# 统计类型数量
|
||||
types_count = len(scene_config.get("types", []))
|
||||
total_types_created += types_count
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene_id}, types_count={types_count}, language={language}"
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}"
|
||||
)
|
||||
|
||||
# 记录总体结果
|
||||
self.logger.info(
|
||||
f"默认场景初始化完成 - workspace_id={workspace_id}, "
|
||||
f"language={language}, scenes_created={scenes_created}, "
|
||||
f"total_types_created={total_types_created}"
|
||||
)
|
||||
|
||||
# 如果至少创建了一个场景,视为成功
|
||||
if scenes_created > 0:
|
||||
return True, ""
|
||||
else:
|
||||
error_msg = "所有默认场景创建失败"
|
||||
self.logger.error(
|
||||
f"默认场景初始化失败 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={error_msg}"
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"默认场景初始化异常: {str(e)}"
|
||||
self.logger.error(
|
||||
f"默认场景初始化异常 - workspace_id={workspace_id}, "
|
||||
f"language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return False, error_msg
|
||||
|
||||
def _create_scene_with_types(
|
||||
self,
|
||||
workspace_id: UUID,
|
||||
scene_config: dict,
|
||||
language: str = "zh"
|
||||
) -> Optional[UUID]:
|
||||
"""创建场景及其类型
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
scene_config: 场景配置字典
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
Optional[UUID]: 创建的场景ID,失败返回None
|
||||
"""
|
||||
try:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
scene_description = get_scene_description(scene_config, language)
|
||||
|
||||
# 检查是否已存在同名场景(支持向后兼容)
|
||||
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
|
||||
if existing_scene:
|
||||
self.logger.info(
|
||||
f"场景已存在,跳过创建 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
|
||||
f"language={language}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 创建场景记录,设置 is_system_default=true
|
||||
scene_data = {
|
||||
"scene_name": scene_name,
|
||||
"scene_description": scene_description
|
||||
}
|
||||
|
||||
scene = self.scene_repo.create(scene_data, workspace_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
scene.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
self.logger.info(
|
||||
f"场景创建成功 - scene_name={scene_name}, "
|
||||
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
# 批量创建类型
|
||||
types_config = scene_config.get("types", [])
|
||||
types_created = self._batch_create_types(scene.scene_id, types_config, language)
|
||||
|
||||
self.logger.info(
|
||||
f"场景类型创建完成 - scene_id={scene.scene_id}, "
|
||||
f"types_created={types_created}/{len(types_config)}, language={language}"
|
||||
)
|
||||
|
||||
return scene.scene_id
|
||||
|
||||
except Exception as e:
|
||||
scene_name = get_scene_name(scene_config, language)
|
||||
self.logger.error(
|
||||
f"场景创建失败 - scene_name={scene_name}, "
|
||||
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
def _batch_create_types(
|
||||
self,
|
||||
scene_id: UUID,
|
||||
types_config: List[dict],
|
||||
language: str = "zh"
|
||||
) -> int:
|
||||
"""批量创建实体类型
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
types_config: 类型配置列表
|
||||
language: 语言类型 ("zh" 或 "en")
|
||||
|
||||
Returns:
|
||||
int: 成功创建的类型数量
|
||||
"""
|
||||
created_count = 0
|
||||
|
||||
for type_config in types_config:
|
||||
try:
|
||||
type_name = get_type_name(type_config, language)
|
||||
type_description = get_type_description(type_config, language)
|
||||
|
||||
# 创建类型数据
|
||||
class_data = {
|
||||
"class_name": type_name,
|
||||
"class_description": type_description
|
||||
}
|
||||
|
||||
# 创建类型
|
||||
ontology_class = self.class_repo.create(class_data, scene_id)
|
||||
|
||||
# 设置系统默认标识
|
||||
ontology_class.is_system_default = True
|
||||
self.db.flush()
|
||||
|
||||
created_count += 1
|
||||
|
||||
self.logger.debug(
|
||||
f"类型创建成功 - class_name={type_name}, "
|
||||
f"class_id={ontology_class.class_id}, "
|
||||
f"scene_id={scene_id}, is_system_default=True, language={language}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
type_name = get_type_name(type_config, language)
|
||||
self.logger.warning(
|
||||
f"单个类型创建失败,继续创建其他类型 - "
|
||||
f"class_name={type_name}, scene_id={scene_id}, "
|
||||
f"language={language}, error={str(e)}"
|
||||
)
|
||||
# 继续创建其他类型
|
||||
continue
|
||||
|
||||
return created_count
|
||||
@@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from typing import Optional, Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas import app_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate
|
||||
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
|
||||
from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -65,7 +67,7 @@ def list_apps(
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
@@ -394,10 +396,10 @@ async def draft_run(
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
service = AppService(db)
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
@@ -482,8 +484,8 @@ async def draft_run(
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -787,8 +789,8 @@ async def draft_run_compare(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
@@ -818,8 +820,8 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
@@ -833,7 +835,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -879,6 +882,60 @@ async def update_workflow_config(
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/workflow/export")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_workflow_config(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)]
|
||||
):
|
||||
"""导出工作流配置为YAML文件"""
|
||||
workflow_service = WorkflowService(db)
|
||||
|
||||
return success(data={
|
||||
"content": workflow_service.export_workflow_dsl(app_id=app_id),
|
||||
})
|
||||
|
||||
|
||||
@router.post("/workflow/import")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_workflow_config(
|
||||
file: UploadFile = File(...),
|
||||
platform: str = Form(...),
|
||||
app_id: str = Form(None),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
|
||||
):
|
||||
"""从YAML内容导入工作流配置"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw_text = (await file.read()).decode("utf-8")
|
||||
import_service = WorkflowImportService(db)
|
||||
config = yaml.safe_load(raw_text)
|
||||
result = await import_service.upload_config(platform, config)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.post("/workflow/import/save")
|
||||
@cur_workspace_access_guard()
|
||||
async def save_workflow_import(
|
||||
data: WorkflowImportSave,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
import_service = WorkflowImportService(db)
|
||||
app = await import_service.save_workflow(
|
||||
user_id=current_user.id,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
temp_id=data.temp_id,
|
||||
name=data.name,
|
||||
description=data.description,
|
||||
)
|
||||
return success(data=app_schema.App.model_validate(app))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
@@ -889,12 +946,14 @@ def get_app_statistics(
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
@@ -931,6 +990,8 @@ def get_workspace_api_statistics(
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
db: 数据库连接
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
|
||||
@@ -441,14 +441,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
|
||||
@@ -90,7 +90,7 @@ async def get_mcp_servers(
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
@@ -118,6 +118,65 @@ async def get_mcp_servers(
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
||||
async def get_operational_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the operational mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + operational mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 3. Return structured response
|
||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
|
||||
@@ -1,27 +1,29 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
@@ -36,7 +38,7 @@ router = APIRouter(
|
||||
|
||||
@router.get("/health/status", response_model=ApiResponse)
|
||||
async def get_health_status(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get latest health status written by Celery periodic task
|
||||
@@ -54,8 +56,9 @@ async def get_health_status(
|
||||
|
||||
@router.get("/download_log")
|
||||
async def download_log(
|
||||
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
log_type: str = Query("file", regex="^(file|transmission)$",
|
||||
description="日志类型: file=完整文件, transmission=实时流式传输"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Download or stream agent service log file
|
||||
@@ -74,16 +77,16 @@ async def download_log(
|
||||
- transmission mode: StreamingResponse with SSE
|
||||
"""
|
||||
api_logger.info(f"Log download requested with log_type={log_type}")
|
||||
|
||||
|
||||
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
|
||||
if log_type not in ["file", "transmission"]:
|
||||
api_logger.warning(f"Invalid log_type parameter: {log_type}")
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
BizCode.BAD_REQUEST,
|
||||
"无效的log_type参数",
|
||||
"log_type必须是'file'或'transmission'"
|
||||
)
|
||||
|
||||
|
||||
# Route to appropriate mode
|
||||
if log_type == "file":
|
||||
# File mode: Return complete log file content
|
||||
@@ -118,10 +121,10 @@ async def download_log(
|
||||
@router.post("/writer_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Write service endpoint - processes write operations synchronously
|
||||
@@ -135,11 +138,11 @@ async def write_server(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -148,7 +151,7 @@ async def write_server(
|
||||
)
|
||||
if storage_type is None: storage_type = 'neo4j'
|
||||
user_rag_memory_id = ''
|
||||
|
||||
|
||||
# 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
|
||||
if storage_type == 'rag':
|
||||
if workspace_id:
|
||||
@@ -160,13 +163,15 @@ async def write_server(
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
else:
|
||||
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
api_logger.warning(
|
||||
f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
else:
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
|
||||
api_logger.info(
|
||||
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
result = await memory_agent_service.write_memory(
|
||||
@@ -174,7 +179,7 @@ async def write_server(
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
storage_type,
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
@@ -194,10 +199,10 @@ async def write_server(
|
||||
@router.post("/writer_service_async", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
@@ -212,10 +217,11 @@ async def write_server_async(
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
api_logger.info(
|
||||
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -243,7 +249,7 @@ async def write_server_async(
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
|
||||
return success(data={"task_id": task.id}, msg="写入任务已提交")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Async write operation failed: {str(e)}")
|
||||
@@ -253,9 +259,9 @@ async def write_server_async(
|
||||
@router.post("/read_service", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def read_server(
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Read service endpoint - processes read operations synchronously
|
||||
@@ -290,8 +296,9 @@ async def read_server(
|
||||
)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
|
||||
api_logger.info(
|
||||
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.end_user_id,
|
||||
@@ -305,7 +312,8 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
|
||||
user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
@@ -318,7 +326,7 @@ async def read_server(
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
result['answer'] = retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -334,9 +342,10 @@ async def read_server(
|
||||
@router.post("/file", response_model=ApiResponse)
|
||||
async def file_update(
|
||||
files: List[UploadFile] = File(..., description="要上传的文件"),
|
||||
model_id:str = Form(..., description="模型ID"),
|
||||
model_id: str = Form(..., description="模型ID"),
|
||||
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
文件上传接口 - 支持图片识别
|
||||
@@ -349,9 +358,6 @@ async def file_update(
|
||||
Returns:
|
||||
文件处理结果
|
||||
"""
|
||||
|
||||
db_gen = get_db() # get_db 通常是一个生成器
|
||||
db = next(db_gen)
|
||||
api_logger.info(f"File upload requested, file count: {len(files)}")
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
apiConfig: ModelApiKey = config.api_keys[0]
|
||||
@@ -360,7 +366,7 @@ async def file_update(
|
||||
for file in files:
|
||||
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
|
||||
content = await file.read()
|
||||
|
||||
|
||||
if file.content_type and file.content_type.startswith("image/"):
|
||||
vision_model = QWenCV(
|
||||
key=apiConfig.api_key,
|
||||
@@ -374,12 +380,12 @@ async def file_update(
|
||||
else:
|
||||
api_logger.warning(f"Unsupported file type: {file.content_type}")
|
||||
file_content.append(f"[不支持的文件类型: {file.content_type}]")
|
||||
|
||||
|
||||
result_text = ';'.join(file_content)
|
||||
api_logger.info(f"File processing completed, result length: {len(result_text)}")
|
||||
|
||||
|
||||
return success(data=result_text, msg="转换文本成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
|
||||
@@ -429,8 +435,8 @@ async def read_server_async(
|
||||
|
||||
@router.get("/read_result/", response_model=ApiResponse)
|
||||
async def get_read_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async read task
|
||||
@@ -451,7 +457,7 @@ async def get_read_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_read_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -469,7 +475,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="查询任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -478,7 +484,7 @@ async def get_read_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -498,7 +504,7 @@ async def get_read_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -506,8 +512,8 @@ async def get_read_task_result(
|
||||
|
||||
@router.get("/write_result/", response_model=ApiResponse)
|
||||
async def get_write_task_result(
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
task_id: str,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get the status and result of an async write task
|
||||
@@ -528,7 +534,7 @@ async def get_write_task_result(
|
||||
try:
|
||||
result = task_service.get_task_memory_write_result(task_id)
|
||||
status = result.get("status")
|
||||
|
||||
|
||||
if status == "SUCCESS":
|
||||
# 任务成功完成
|
||||
task_result = result.get("result", {})
|
||||
@@ -546,7 +552,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
# 旧格式:直接返回结果
|
||||
return success(data=task_result, msg="写入任务已完成")
|
||||
|
||||
|
||||
elif status == "FAILURE":
|
||||
# 任务失败
|
||||
error_info = result.get("result", "Unknown error")
|
||||
@@ -555,7 +561,7 @@ async def get_write_task_result(
|
||||
else:
|
||||
error_msg = str(error_info)
|
||||
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
|
||||
|
||||
|
||||
elif status in ["PENDING", "STARTED"]:
|
||||
# 任务进行中
|
||||
return success(
|
||||
@@ -575,7 +581,7 @@ async def get_write_task_result(
|
||||
},
|
||||
msg=f"任务状态: {status}"
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
|
||||
@@ -583,9 +589,9 @@ async def get_write_task_result(
|
||||
|
||||
@router.post("/status_type", response_model=ApiResponse)
|
||||
async def status_type(
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
user_input: Write_UserInput,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
@@ -628,9 +634,10 @@ async def status_type(
|
||||
|
||||
@router.get("/stats/types", response_model=ApiResponse)
|
||||
async def get_knowledge_type_stats_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
@@ -639,14 +646,9 @@ async def get_knowledge_type_stats_api(
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
api_logger.info(
|
||||
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
from app.db import get_db
|
||||
|
||||
# 获取数据库会话
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
# 调用service层函数
|
||||
result = await memory_agent_service.get_knowledge_type_stats(
|
||||
end_user_id=end_user_id,
|
||||
@@ -654,48 +656,70 @@ async def get_knowledge_type_stats_api(
|
||||
current_workspace_id=current_user.current_workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="获取知识库类型统计成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Knowledge type stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
async def get_user_profile_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -733,17 +757,17 @@ async def get_user_profile_api(
|
||||
# ):
|
||||
# """
|
||||
# Get parsed API documentation (Public endpoint - no authentication required)
|
||||
|
||||
|
||||
# Args:
|
||||
# file_path: Optional path to API docs file. If None, uses default path.
|
||||
|
||||
|
||||
# Returns:
|
||||
# Parsed API documentation including title, meta info, and sections
|
||||
# """
|
||||
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
|
||||
# try:
|
||||
# result = await memory_agent_service.get_api_docs(file_path)
|
||||
|
||||
|
||||
# if result.get("success"):
|
||||
# return success(msg=result["msg"], data=result["data"])
|
||||
# else:
|
||||
@@ -759,9 +783,9 @@ async def get_user_profile_api(
|
||||
|
||||
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
|
||||
async def get_end_user_connected_config(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
@@ -780,9 +804,9 @@ async def get_end_user_connected_config(
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config as get_config,
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
|
||||
try:
|
||||
result = get_config(end_user_id, db)
|
||||
return success(data=result, msg="获取终端用户关联配置成功")
|
||||
@@ -791,4 +815,4 @@ async def get_end_user_connected_config(
|
||||
return fail(BizCode.NOT_FOUND, str(e))
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
|
||||
|
||||
@@ -606,8 +606,8 @@ async def dashboard_data(
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -29,11 +31,11 @@ async def short_term_configs(
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
|
||||
long_term=LongService(end_user_id)
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
@@ -85,6 +85,7 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -99,7 +100,29 @@ def create_config(
|
||||
svc = DataConfigService(db)
|
||||
result = svc.create(payload)
|
||||
return success(data=result, msg="创建成功")
|
||||
except ValueError as e:
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
|
||||
config_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {err_str}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
|
||||
except Exception as e:
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
|
||||
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Create config failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
|
||||
|
||||
@@ -521,10 +544,11 @@ async def clear_hot_memory_tags_cache(
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info("Recent activity stats requested")
|
||||
) -> dict:
|
||||
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
|
||||
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await analytics_recent_activity_stats()
|
||||
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
|
||||
@@ -371,6 +371,11 @@ def update_model(
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
if model_data.is_active:
|
||||
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
|
||||
if not active_keys:
|
||||
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
@@ -469,7 +474,9 @@ async def create_model_api_key_by_provider(
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
model_config_ids=model_config_ids,
|
||||
capability=api_key_data.capability,
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
|
||||
@@ -25,13 +25,13 @@ from typing import Dict, Optional, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(
|
||||
@@ -123,15 +124,23 @@ def _get_ontology_service(
|
||||
)
|
||||
|
||||
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
if not api_keys:
|
||||
# from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
|
||||
if not api_key_config:
|
||||
logger.error(f"Model {llm_id} has no active API key")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="指定的LLM模型没有可用的API密钥"
|
||||
)
|
||||
api_key_config = api_keys[0]
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
# if not api_keys:
|
||||
# logger.error(f"Model {llm_id} has no active API key")
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail="指定的LLM模型没有可用的API密钥"
|
||||
# )
|
||||
# api_key_config = api_keys[0]
|
||||
|
||||
is_composite = getattr(model_config, 'is_composite', False)
|
||||
logger.info(
|
||||
@@ -153,6 +162,7 @@ def _get_ontology_service(
|
||||
provider=actual_provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
@@ -279,7 +289,8 @@ async def extract_ontology(
|
||||
async def create_scene(
|
||||
request: SceneCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体场景
|
||||
|
||||
@@ -350,8 +361,18 @@ async def create_scene(
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
|
||||
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
|
||||
@@ -399,6 +420,20 @@ async def update_scene(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认场景
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(scene_uuid)
|
||||
if scene and scene.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认场景: user_id={current_user.id}, "
|
||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认场景不可修改",
|
||||
"该场景为系统预设场景,不允许修改"
|
||||
)
|
||||
|
||||
# 创建OntologyService实例
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -491,6 +526,19 @@ async def delete_scene(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认场景
|
||||
scene_repo = OntologySceneRepository(db)
|
||||
scene = scene_repo.get_by_id(scene_uuid)
|
||||
if scene and scene.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认场景: user_id={current_user.id}, "
|
||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
|
||||
)
|
||||
|
||||
# 创建OntologyService实例
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
@@ -514,6 +562,9 @@ async def delete_scene(
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="场景删除成功")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
@@ -621,7 +672,8 @@ async def get_scenes(
|
||||
async def create_class(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
|
||||
):
|
||||
"""创建本体类型
|
||||
|
||||
@@ -636,7 +688,7 @@ async def create_class(
|
||||
ApiResponse: 包含创建的类型信息
|
||||
"""
|
||||
from app.controllers.ontology_secondary_routes import create_class_handler
|
||||
return await create_class_handler(request, db, current_user)
|
||||
return await create_class_handler(request, db, current_user, x_language_type)
|
||||
|
||||
|
||||
@router.put("/class/{class_id}", response_model=ApiResponse)
|
||||
|
||||
@@ -7,11 +7,11 @@
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import Depends, Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.logging_config import get_api_logger, get_business_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
@@ -56,7 +58,7 @@ async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
pagesize: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -69,14 +71,14 @@ async def scenes_handler(
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
pagesize: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -103,13 +105,13 @@ async def scenes_handler(
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
@@ -117,17 +119,15 @@ async def scenes_handler(
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
if page is not None and pagesize is not None:
|
||||
start_idx = (page - 1) * pagesize
|
||||
end_idx = start_idx + pagesize
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -139,17 +139,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -163,28 +162,25 @@ async def scenes_handler(
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
if pagesize is not None and pagesize < 1:
|
||||
api_logger.warning(f"Invalid pagesize: {pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
# 如果只提供了page或pagesize中的一个,返回错误
|
||||
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
@@ -196,17 +192,16 @@ async def scenes_handler(
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
classes_count=type_num,
|
||||
is_system_default=scene.is_system_default
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
if page is not None and pagesize is not None:
|
||||
hasnext = (page * pagesize) < total
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
pagesize=pagesize,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
@@ -236,7 +231,8 @@ async def scenes_handler(
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
x_language_type: Optional[str] = None
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
@@ -269,8 +265,11 @@ async def create_class_handler(
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
# 单个创建 - 先检查重名
|
||||
class_data = classes_data[0]
|
||||
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
|
||||
if existing:
|
||||
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
@@ -328,12 +327,36 @@ async def create_class_handler(
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
err_str = str(e)
|
||||
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
|
||||
class_name = err_str.split(":", 1)[1]
|
||||
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.warning(f"Validation error in class creation: {err_str}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
err_str = str(e)
|
||||
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
|
||||
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from fastapi.responses import JSONResponse
|
||||
lang = get_language_from_header(x_language_type)
|
||||
class_name = request.classes[0].class_name if request.classes else ""
|
||||
if lang == "en":
|
||||
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
|
||||
else:
|
||||
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
|
||||
return JSONResponse(status_code=400, content=msg)
|
||||
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
@@ -366,6 +389,20 @@ async def update_class_handler(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试修改系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可修改",
|
||||
"该类型为系统预设类型,不允许修改"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
@@ -429,6 +466,20 @@ async def delete_class_handler(
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 检查是否为系统默认类型
|
||||
class_repo = OntologyClassRepository(db)
|
||||
ontology_class = class_repo.get_by_id(class_uuid)
|
||||
if ontology_class and ontology_class.is_system_default:
|
||||
business_logger.warning(
|
||||
f"尝试删除系统默认类型: user_id={current_user.id}, "
|
||||
f"class_id={class_id}, class_name={ontology_class.class_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认类型不可删除",
|
||||
"该类型为系统预设类型,不允许删除"
|
||||
)
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
@@ -585,6 +636,7 @@ async def classes_handler(
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
is_system_default=scene.is_system_default,
|
||||
items=items
|
||||
)
|
||||
|
||||
|
||||
@@ -2,25 +2,32 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
@@ -206,15 +213,13 @@ def list_conversations(
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
@@ -293,19 +298,15 @@ async def chat(
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
share, release = service.get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
@@ -318,7 +319,6 @@ async def chat(
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
@@ -359,12 +359,12 @@ async def chat(
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
if app_type == AppType.AGENT:
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
@@ -638,6 +638,34 @@ async def chat(
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.get("/config", summary="获取应用启动配置")
|
||||
async def config_query(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
share_service = SharedChatService(db)
|
||||
share_token = share_data.share_token
|
||||
share, release = share_service.get_release_by_share_token(share_token, password)
|
||||
if release.app.type == AppType.WORKFLOW:
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
return success(data=content)
|
||||
|
||||
@@ -89,7 +89,6 @@ async def chat(
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
|
||||
other_id = payload.user_id
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
@@ -135,7 +134,8 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=end_user_id,
|
||||
is_draft=False
|
||||
is_draft=False,
|
||||
conversation_id=payload.conversation_id
|
||||
)
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
@@ -249,6 +249,7 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -95,16 +95,29 @@ def get_workspaces(
|
||||
@router.post("", response_model=ApiResponse)
|
||||
def create_workspace(
|
||||
workspace: WorkspaceCreate,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
|
||||
from app.core.language_utils import get_language_from_header
|
||||
|
||||
# 验证并获取语言参数
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
|
||||
f"language={language}"
|
||||
)
|
||||
|
||||
result = workspace_service.create_workspace(
|
||||
db=db, workspace=workspace, user=current_user)
|
||||
db=db, workspace=workspace, user=current_user, language=language
|
||||
)
|
||||
|
||||
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
|
||||
api_logger.info(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
|
||||
@@ -11,35 +11,37 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -60,12 +62,13 @@ class LangChainAgent:
|
||||
self.provider = provider
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
@@ -73,9 +76,9 @@ class LangChainAgent:
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -89,6 +92,7 @@ class LangChainAgent:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -143,21 +147,22 @@ class LangChainAgent:
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
@@ -168,13 +173,13 @@ class LangChainAgent:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
@@ -185,12 +190,12 @@ class LangChainAgent:
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
@@ -198,17 +203,17 @@ class LangChainAgent:
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -248,7 +253,7 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -261,23 +266,26 @@ class LangChainAgent:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
@@ -302,7 +310,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
@@ -322,8 +330,8 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -367,14 +375,14 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
@@ -407,12 +415,13 @@ class LangChainAgent:
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -439,16 +448,16 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -482,7 +491,6 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
@@ -500,13 +508,13 @@ class LangChainAgent:
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
@@ -540,7 +548,7 @@ class LangChainAgent:
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
@@ -577,13 +585,13 @@ class LangChainAgent:
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
@@ -595,7 +603,8 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -609,5 +618,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -16,18 +17,18 @@ class Settings:
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -57,7 +58,6 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
@@ -91,7 +91,7 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
@@ -130,7 +130,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -190,8 +190,12 @@ class Settings:
|
||||
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
|
||||
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
|
||||
# 详见 docs/celery-env-bug-report.md
|
||||
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
|
||||
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
|
||||
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
|
||||
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
@@ -201,21 +205,30 @@ class Settings:
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Periodic Task Schedule Configuration
|
||||
# workspace_reflection: 每隔多少秒执行一次
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))
|
||||
# forgetting_cycle: 每隔多少小时执行一次
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))
|
||||
# implicit_emotions_update: 每天几点执行(小时,0-23)
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
||||
# Memory Module Configuration (internal)
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
@@ -232,27 +245,28 @@ class Settings:
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
@@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
@@ -220,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
|
||||
@@ -6,31 +6,26 @@ import os
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db, get_db_context
|
||||
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
COUNTState,
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
@@ -50,10 +45,12 @@ async def rag_config(state):
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception :
|
||||
retrieval_knowledge=[]
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:
|
||||
|
||||
# 收集所有内容
|
||||
content_list = []
|
||||
|
||||
|
||||
# 处理重排序结果
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
for key,values in problem_extension.items():
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j!=['']:
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve':dup_databases}
|
||||
|
||||
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
tools=[time_retrieval_tool, hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
async with SEMAPHORE: # 限制并发
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
@@ -17,33 +15,77 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception:
|
||||
retrieval_knowledge = []
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge, clean_content, cleaned_query, raw_results
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
|
||||
# 构建系统提示词
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
@@ -62,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
if structured is None:
|
||||
logger.warning(f"LLM返回None,使用默认回答")
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
# 根据操作类型提取答案
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
@@ -82,18 +125,18 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning(f"结构化响应缺少data字段")
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
# 验证答案不为空
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
return aimessages
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
@@ -103,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
system_prompt=system_prompt,
|
||||
fallback_message="信息不足,无法回答"
|
||||
)
|
||||
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
cleaned_response = response.strip()
|
||||
@@ -111,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
|
||||
|
||||
return cleaned_response
|
||||
else:
|
||||
return "信息不足,无法回答"
|
||||
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -132,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
storage_type=state.get("storage_type",'')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
input_summary = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
@@ -152,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
retrieve={
|
||||
retrieve = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"title":"快速检索",
|
||||
"title": "快速检索",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"storage_type": storage_type,
|
||||
@@ -167,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
return input_summary,retrieve
|
||||
return input_summary, retrieve
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start=time.time()
|
||||
storage_type=state.get("storage_type",'')
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
history = await summary_history(state)
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
@@ -186,12 +233,14 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
}
|
||||
|
||||
try:
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
|
||||
retrieve_info, question, raw_results = "", data, []
|
||||
|
||||
|
||||
try:
|
||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||
# 'input_summary',RetrieveSummaryResponse)
|
||||
@@ -199,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||
summary = summary_result[0]
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
||||
summary= {
|
||||
logger.error(f"Input_Summary failed: {e}", exc_info=True)
|
||||
summary = {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
@@ -213,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve=state.get("retrieve", '')
|
||||
history = await summary_history( state)
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
with open("检索.json","w",encoding='utf-8') as f:
|
||||
with open("检索.json", "w", encoding='utf-8') as f:
|
||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||
retrieve=retrieve.get("Expansion_issue", [])
|
||||
start=time.time()
|
||||
retrieve_info_str=[]
|
||||
retrieve = retrieve.get("Expansion_issue", [])
|
||||
start = time.time()
|
||||
retrieve_info_str = []
|
||||
for data in retrieve:
|
||||
if data=='':
|
||||
retrieve_info_str=''
|
||||
if data == '':
|
||||
retrieve_info_str = ''
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if key=='Answer_Small':
|
||||
if key == 'Answer_Small':
|
||||
for i in value:
|
||||
retrieve_info_str.append(i)
|
||||
retrieve_info_str=list(set(retrieve_info_str))
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
retrieve_info_str = list(set(retrieve_info_str))
|
||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -248,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState)-> ReadState:
|
||||
start=time.time()
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify=state.get("verify", '')
|
||||
verify_expansion_issue=verify.get("verified_data", '')
|
||||
retrieve_info_str=''
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key=='answer_small':
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str+=i+'\n'
|
||||
history=await summary_history(state)
|
||||
retrieve_info_str += i + '\n'
|
||||
history = await summary_history(state)
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
@@ -289,11 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
return {"summary": summary}
|
||||
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -309,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
return {"summary":result}
|
||||
return {"summary": result}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.db import get_db_context
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": verified_data,
|
||||
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
|
||||
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
logger.info(
|
||||
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
@@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
|
||||
for lang in ["zh", "en"]:
|
||||
deleted = await InterestMemoryCache.delete_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=lang,
|
||||
)
|
||||
if deleted:
|
||||
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
|
||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
)
|
||||
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
@@ -49,7 +47,7 @@ async def make_read_graph():
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
@@ -62,20 +60,20 @@ async def make_read_graph():
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
|
||||
# 编译工作流
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
@@ -92,17 +90,19 @@ async def main():
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start=time.time()
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
@@ -110,7 +110,7 @@ async def main():
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
@@ -125,23 +125,22 @@ async def main():
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
@@ -161,17 +160,20 @@ async def main():
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
end=time.time()
|
||||
print(100*'y')
|
||||
print(f"总耗时: {end-start}s")
|
||||
print(100*'y')
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
print(f"总耗时: {end - start}s")
|
||||
print(100 * 'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -57,6 +57,63 @@ async def get_chunked_dialogs(
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 加载剪枝配置
|
||||
pruning_config = None
|
||||
if config_id:
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="semantic_pruning"
|
||||
)
|
||||
|
||||
if memory_config:
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_classes=memory_config.ontology_classes,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
# 获取LLM客户端用于剪枝
|
||||
if pruning_config.pruning_switch:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
||||
|
||||
if pruned_dialogs:
|
||||
dialog_data = pruned_dialogs[0]
|
||||
remaining_msg_count = len(dialog_data.context.msgs)
|
||||
deleted_count = original_msg_count - remaining_msg_count
|
||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||
else:
|
||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||
else:
|
||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class LLMClientPool:
|
||||
"""LLM客户端连接池"""
|
||||
|
||||
def __init__(self, max_size: int = 5):
|
||||
self.max_size = max_size
|
||||
self.pools: Dict[str, asyncio.Queue] = {}
|
||||
self.active_clients: Dict[str, int] = {}
|
||||
|
||||
async def get_client(self, llm_model_id: str):
|
||||
"""获取LLM客户端"""
|
||||
if llm_model_id not in self.pools:
|
||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||
self.active_clients[llm_model_id] = 0
|
||||
|
||||
pool = self.pools[llm_model_id]
|
||||
|
||||
try:
|
||||
# 尝试从池中获取客户端
|
||||
client = pool.get_nowait()
|
||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
except asyncio.QueueEmpty:
|
||||
# 池为空,创建新客户端
|
||||
if self.active_clients[llm_model_id] < self.max_size:
|
||||
db_session = next(get_db())
|
||||
client = get_llm_client_fast(llm_model_id, db_session)
|
||||
self.active_clients[llm_model_id] += 1
|
||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
else:
|
||||
# 等待可用客户端
|
||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||
return await pool.get()
|
||||
|
||||
async def return_client(self, llm_model_id: str, client):
|
||||
"""归还LLM客户端到池中"""
|
||||
if llm_model_id in self.pools:
|
||||
try:
|
||||
self.pools[llm_model_id].put_nowait(client)
|
||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||
except asyncio.QueueFull:
|
||||
# 池已满,丢弃客户端
|
||||
self.active_clients[llm_model_id] -= 1
|
||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||
|
||||
# 全局客户端池
|
||||
llm_client_pool = LLMClientPool()
|
||||
@@ -225,5 +225,24 @@ async def write(
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
# 将提取统计写入 Redis,按 workspace_id 存储
|
||||
try:
|
||||
from app.cache.memory.activity_stats_cache import ActivityStatsCache
|
||||
|
||||
stats_to_cache = {
|
||||
"chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0,
|
||||
"statements_count": len(all_statement_nodes) if all_statement_nodes else 0,
|
||||
"triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0,
|
||||
"triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
await ActivityStatsCache.set_activity_stats(
|
||||
workspace_id=str(memory_config.workspace_id),
|
||||
stats=stats_to_cache,
|
||||
)
|
||||
logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
class InterestTags(BaseModel):
|
||||
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
||||
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
return structured_response.meaningful_tags
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM筛选过程中发生错误: {e}")
|
||||
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
||||
"""
|
||||
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
||||
|
||||
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
end_user_id: 用户ID,用于获取LLM配置
|
||||
|
||||
Returns:
|
||||
筛选后的兴趣活动标签列表
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}."
|
||||
)
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
tag_list_str = ", ".join(tags)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
||||
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": rendered_prompt
|
||||
}
|
||||
]
|
||||
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestTags
|
||||
)
|
||||
|
||||
return structured_response.interest_tags
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
return tags
|
||||
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
@@ -139,14 +210,14 @@ async def get_raw_tags_from_db(
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
@@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
@@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
# 4. 限制返回的标签数量
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取用户的兴趣分布标签。
|
||||
|
||||
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
||||
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 查询更多原始标签,给LLM提供充足上下文
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
||||
|
||||
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
||||
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
||||
|
||||
# 构建最终标签列表:
|
||||
# - 原始标签中存在的,保留原始频率
|
||||
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
||||
final_tags = []
|
||||
seen = set()
|
||||
for tag in interest_tag_names:
|
||||
if tag in seen:
|
||||
continue
|
||||
seen.add(tag)
|
||||
freq = raw_freq_map.get(tag, 1)
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
# 按频率降序排列
|
||||
final_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
@@ -10,7 +10,7 @@ Classes:
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -55,17 +55,26 @@ class PruningConfig(BaseModel):
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound')
|
||||
pruning_scene: Scene name for pruning, either a built-in key
|
||||
('education', 'online_service', 'outbound') or a custom scene_name
|
||||
from ontology_scene table
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
||||
ontology_classes: List of class_name strings from ontology_class table,
|
||||
injected into the prompt when pruning_scene is not a built-in scene
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: one of 'education', 'online_service', 'outbound'.",
|
||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||
ontology_classes: Optional[List[str]] = Field(
|
||||
None, description="Class names from ontology_class table for custom scenes."
|
||||
)
|
||||
|
||||
|
||||
class TemporalSearchParams(BaseModel):
|
||||
|
||||
@@ -5,20 +5,27 @@
|
||||
- 对话级一次性抽取判定相关性
|
||||
- 仅对"不相关对话"的消息按比例删除
|
||||
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
|
||||
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Tuple, Set
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||
SceneConfigRegistry,
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
@@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel):
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
"""消息重要性批量判断的结构化返回(用于LLM语义判断)。
|
||||
|
||||
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
|
||||
- reasons: 可选的判断理由
|
||||
"""
|
||||
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
|
||||
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
"""问答对模型,用于识别和保护对话中的问答结构。"""
|
||||
question_idx: int = Field(..., description="问题消息的索引")
|
||||
answer_idx: int = Field(..., description="答案消息的索引")
|
||||
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
|
||||
|
||||
|
||||
class SemanticPruner:
|
||||
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
|
||||
|
||||
@@ -43,109 +67,385 @@ class SemanticPruner:
|
||||
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
|
||||
cfg_dict = get_pruning_config() if config is None else config.model_dump()
|
||||
self.config = PruningConfig.model_validate(cfg_dict)
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
|
||||
# 如果没有提供config,使用默认配置
|
||||
if config is None:
|
||||
# 使用默认的剪枝配置
|
||||
config = PruningConfig(
|
||||
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
|
||||
pruning_scene="education",
|
||||
pruning_threshold=0.5
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.llm_client = llm_client
|
||||
self.language = language # 保存语言配置
|
||||
self.max_concurrent = max_concurrent # 新增:最大并发数
|
||||
|
||||
# 详细日志配置:限制逐条消息日志的数量
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
|
||||
# 判断是否为内置专门场景
|
||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
|
||||
# 自定义场景的本体类型列表(用于注入提示词)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
|
||||
if self._is_builtin_scene:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
|
||||
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
|
||||
|
||||
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
|
||||
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
|
||||
self._cache_max_size = 1000 # 缓存大小限制
|
||||
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
# 采用顺序处理,移除并发配置以简化与稳定执行
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
patterns = [
|
||||
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
|
||||
r"\b\d{1,2}:\d{2}\b",
|
||||
r"\d{4}年\d{1,2}月\d{1,2}日",
|
||||
r"上午|下午|AM|PM",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户",
|
||||
r"电话|手机号|微信|QQ|邮箱",
|
||||
r"地址|地点",
|
||||
r"金额|费用|价格|¥|¥|\d+元",
|
||||
r"时间|日期|有效期|截止",
|
||||
]
|
||||
for p in patterns:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
简单启发:匹配到的类别越多、越关键分值越高。
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
weights = [
|
||||
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
|
||||
(r"\b\d{1,2}:\d{2}\b", 2),
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
|
||||
(r"电话|手机号|微信|QQ|邮箱", 3),
|
||||
(r"地址|地点", 2),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 4),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
]
|
||||
for p, w in weights:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
score += w
|
||||
return score
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
|
||||
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
"""
|
||||
import re
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
# 常见填充语
|
||||
fillers = [
|
||||
"你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢",
|
||||
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??"
|
||||
]
|
||||
if t in fillers:
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
# 长度与字符类型判断
|
||||
if len(t) <= 8:
|
||||
# 非数字、无关键实体的短文本
|
||||
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
|
||||
# 主要是标点或简单确认词
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
|
||||
return True
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
"你好", "您好", "hello", "hi",
|
||||
"拜拜", "再见", "拜", "88", "bye",
|
||||
"好的", "好", "行", "可以", "嗯", "哦", "啊",
|
||||
"是的", "对", "对的", "没错", "是啊",
|
||||
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _batch_evaluate_importance_with_llm(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
context: str = ""
|
||||
) -> Dict[int, int]:
|
||||
"""使用LLM批量评估消息的重要性(语义层面)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
context: 对话上下文(可选)
|
||||
|
||||
Returns:
|
||||
消息索引到重要性分数(0-10)的映射
|
||||
"""
|
||||
if not self.llm_client or not messages:
|
||||
return {}
|
||||
|
||||
# 构建批量评估的提示词
|
||||
msg_list = []
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_list.append(f"{idx}. {msg.msg}")
|
||||
|
||||
msg_text = "\n".join(msg_list)
|
||||
|
||||
prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分):
|
||||
- 0-2分:无意义的寒暄、口头禅、纯表情
|
||||
- 3-5分:一般性对话,有一定信息量但不关键
|
||||
- 6-8分:包含重要信息(时间、地点、人物、事件等)
|
||||
- 9-10分:关键决策、承诺、重要数据
|
||||
|
||||
对话上下文:
|
||||
{context if context else "无"}
|
||||
|
||||
待评估的消息:
|
||||
{msg_text}
|
||||
|
||||
请以JSON格式返回,格式为:
|
||||
{{
|
||||
"importance_scores": {{
|
||||
"0": 分数,
|
||||
"1": 分数,
|
||||
...
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
messages_for_llm = [
|
||||
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages_for_llm,
|
||||
MessageImportanceResponse
|
||||
)
|
||||
|
||||
# 转换字符串键为整数键
|
||||
return {int(k): v for k, v in response.importance_scores.items()}
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
|
||||
return {}
|
||||
|
||||
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
|
||||
"""识别对话中的问答对,用于保护问答结构的完整性。
|
||||
|
||||
改进版:使用场景特定的问句关键词,并排除寒暄类问句
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
问答对列表
|
||||
"""
|
||||
qa_pairs = []
|
||||
|
||||
# 寒暄类问句,不应该被保护(这些不是真正的问答)
|
||||
greeting_questions = {
|
||||
"在吗", "在不在", "你好吗", "怎么样", "好吗",
|
||||
"有空吗", "忙吗", "睡了吗", "起床了吗"
|
||||
}
|
||||
|
||||
for i in range(len(messages) - 1):
|
||||
current_msg = messages[i].msg.strip()
|
||||
next_msg = messages[i + 1].msg.strip()
|
||||
|
||||
# 排除寒暄类问句
|
||||
if current_msg in greeting_questions:
|
||||
continue
|
||||
|
||||
# 使用场景特定的问句关键词,但要求更严格
|
||||
is_question = False
|
||||
|
||||
# 1. 以问号结尾
|
||||
if current_msg.endswith("?") or current_msg.endswith("?"):
|
||||
is_question = True
|
||||
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
|
||||
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]):
|
||||
is_question = True
|
||||
|
||||
if is_question and next_msg:
|
||||
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
|
||||
is_answer = not (next_msg.endswith("?") or next_msg.endswith("?"))
|
||||
|
||||
# 排除寒暄类回复
|
||||
greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"}
|
||||
if next_msg in greeting_answers:
|
||||
is_answer = False
|
||||
|
||||
if is_answer:
|
||||
qa_pairs.append(QAPair(
|
||||
question_idx=i,
|
||||
answer_idx=i + 1,
|
||||
confidence=0.8 # 基于规则的置信度
|
||||
))
|
||||
|
||||
return qa_pairs
|
||||
|
||||
def _get_protected_indices(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
qa_pairs: List[QAPair],
|
||||
window_size: int = 2
|
||||
) -> Set[int]:
|
||||
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
qa_pairs: 问答对列表
|
||||
window_size: 上下文窗口大小(前后各保留几条消息)
|
||||
|
||||
Returns:
|
||||
需要保护的消息索引集合
|
||||
"""
|
||||
protected = set()
|
||||
|
||||
for qa_pair in qa_pairs:
|
||||
# 保护问答对本身
|
||||
protected.add(qa_pair.question_idx)
|
||||
protected.add(qa_pair.answer_idx)
|
||||
|
||||
# 保护上下文窗口
|
||||
for offset in range(-window_size, window_size + 1):
|
||||
q_idx = qa_pair.question_idx + offset
|
||||
a_idx = qa_pair.answer_idx + offset
|
||||
|
||||
if 0 <= q_idx < len(messages):
|
||||
protected.add(q_idx)
|
||||
if 0 <= a_idx < len(messages):
|
||||
protected.add(a_idx)
|
||||
|
||||
return protected
|
||||
|
||||
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
|
||||
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
|
||||
|
||||
- 仅使用 LLM 结构化输出;
|
||||
改进版:
|
||||
- LRU缓存管理
|
||||
- 重试机制
|
||||
- 降级策略
|
||||
"""
|
||||
# 缓存命中则直接返回(场景+内容作为键)
|
||||
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
|
||||
|
||||
# LRU缓存:如果命中,移到末尾(最近使用)
|
||||
if cache_key in self._dialog_extract_cache:
|
||||
self._dialog_extract_cache.move_to_end(cache_key)
|
||||
return self._dialog_extract_cache[cache_key]
|
||||
|
||||
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
|
||||
# LRU缓存大小限制:超过限制时删除最旧的条目
|
||||
if len(self._dialog_extract_cache) >= self._cache_max_size:
|
||||
# 删除最旧的条目(OrderedDict的第一个)
|
||||
oldest_key = next(iter(self._dialog_extract_cache))
|
||||
del self._dialog_extract_cache[oldest_key]
|
||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
is_builtin_scene=self._is_builtin_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"is_builtin_scene": self._is_builtin_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
# 强制使用 LLM;移除正则回退
|
||||
# 强制使用 LLM
|
||||
if not self.llm_client:
|
||||
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
|
||||
|
||||
@@ -153,12 +453,32 @@ class SemanticPruner:
|
||||
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
|
||||
{"role": "user", "content": rendered},
|
||||
]
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
|
||||
|
||||
# 重试机制
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
|
||||
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
else:
|
||||
# 降级策略:标记为相关,避免误删
|
||||
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
|
||||
fallback_response = DialogExtractionResponse(
|
||||
is_related=True,
|
||||
times=[],
|
||||
ids=[],
|
||||
amounts=[],
|
||||
contacts=[],
|
||||
addresses=[],
|
||||
keywords=[]
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
@@ -248,12 +568,14 @@ class SemanticPruner:
|
||||
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""数据集层面:全局消息级剪枝,保留所有对话。
|
||||
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
|
||||
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话。
|
||||
改进版:
|
||||
- 消息级独立判断,每条消息根据场景规则独立评估
|
||||
- 问答对保护已注释(暂不启用,留作观察)
|
||||
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话
|
||||
"""
|
||||
# 如果剪枝功能关闭,直接返回原始数据集。
|
||||
# 如果剪枝功能关闭,直接返回原始数据集
|
||||
if not self.config.pruning_switch:
|
||||
return dialogs
|
||||
|
||||
@@ -264,179 +586,140 @@ class SemanticPruner:
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
|
||||
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||
)
|
||||
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
|
||||
evaluated_dialogs = []
|
||||
for idx, dd in enumerate(dialogs):
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
})
|
||||
except Exception:
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": True,
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
})
|
||||
|
||||
# 统计相关 / 不相关对话
|
||||
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
|
||||
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
|
||||
self._log(
|
||||
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
|
||||
)
|
||||
|
||||
# 简洁打印第几段对话相关/不相关(索引基于1)
|
||||
def _fmt_indices(items, cap: int = 10):
|
||||
inds = [i["index"] + 1 for i in items]
|
||||
if len(inds) <= cap:
|
||||
return inds
|
||||
# 超过上限时只打印前cap个,并标注总数
|
||||
return inds[:cap] + ["...", f"共{len(inds)}个"]
|
||||
|
||||
rel_inds = _fmt_indices(related_dialogs)
|
||||
nrel_inds = _fmt_indices(not_related_dialogs)
|
||||
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段")
|
||||
|
||||
|
||||
result: List[DialogData] = []
|
||||
if not_related_dialogs:
|
||||
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM)
|
||||
per_dialog_info = {}
|
||||
total_unrelated = 0
|
||||
total_capacity = 0
|
||||
for d in not_related_dialogs:
|
||||
dd = d["dialog"]
|
||||
extraction = d.get("extraction")
|
||||
if extraction is None:
|
||||
extraction = await self._extract_dialog_important(dd.content)
|
||||
# 合并所有重要标记
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dd.context.msgs
|
||||
# 分类消息
|
||||
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
|
||||
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
|
||||
# 重要消息按重要性排序
|
||||
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
|
||||
info = {
|
||||
"dialog": dd,
|
||||
"total_msgs": len(msgs),
|
||||
"unrelated_count": len(msgs),
|
||||
"imp_ids_sorted": imp_sorted_ids,
|
||||
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
|
||||
}
|
||||
per_dialog_info[d["index"]] = info
|
||||
total_unrelated += info["unrelated_count"]
|
||||
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
|
||||
global_delete = int(total_unrelated * proportion)
|
||||
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
|
||||
global_delete = 1
|
||||
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息
|
||||
capacities = []
|
||||
for d in not_related_dialogs:
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
# 统计重要数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
unimp_count = len(info["unimp_ids"])
|
||||
imp_cap = int(imp_count * proportion)
|
||||
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
|
||||
capacities.append(cap)
|
||||
total_capacity = sum(capacities)
|
||||
if global_delete > total_capacity:
|
||||
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
|
||||
global_delete = total_capacity
|
||||
|
||||
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
|
||||
alloc = []
|
||||
for i, d in enumerate(not_related_dialogs):
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
|
||||
alloc.append(min(share, capacities[i]))
|
||||
allocated = sum(alloc)
|
||||
rem = global_delete - allocated
|
||||
turn = 0
|
||||
while rem > 0 and turn < 100000:
|
||||
progressed = False
|
||||
for i in range(len(not_related_dialogs)):
|
||||
if rem <= 0:
|
||||
break
|
||||
if alloc[i] < capacities[i]:
|
||||
alloc[i] += 1
|
||||
rem -= 1
|
||||
progressed = True
|
||||
if not progressed:
|
||||
break
|
||||
turn += 1
|
||||
|
||||
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
|
||||
total_deleted_confirm = 0
|
||||
for d in evaluated_dialogs:
|
||||
dd = d["dialog"]
|
||||
msgs = dd.context.msgs
|
||||
original = len(msgs)
|
||||
if d["is_related"]:
|
||||
result.append(dd)
|
||||
continue
|
||||
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
|
||||
if idx_in_unrel is None:
|
||||
result.append(dd)
|
||||
continue
|
||||
quota = alloc[idx_in_unrel]
|
||||
info = per_dialog_info[d["index"]]
|
||||
# 计算本对话重要最多可删数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
imp_del_cap = int(imp_count * proportion)
|
||||
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
|
||||
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
|
||||
del_unimp = min(quota, len(unimp_delete_ids))
|
||||
rem_quota = quota - del_unimp
|
||||
# 再从重要里选低分优先的删除ID(不超过 imp_del_cap)
|
||||
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
|
||||
deleted_here = 0
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept = []
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
|
||||
actual_imp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
kept.append(m)
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
dd.context.msgs = kept
|
||||
total_deleted_confirm += deleted_here
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
|
||||
)
|
||||
result.append(dd)
|
||||
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
|
||||
else:
|
||||
# 全部相关:不执行剪枝
|
||||
result = [d["dialog"] for d in evaluated_dialogs]
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
if proportion > 0 and original_count > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
|
||||
# 确保至少保留1条消息
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
for idx, m in enumerate(msgs):
|
||||
if idx not in to_delete_indices:
|
||||
kept_msgs.append(m)
|
||||
|
||||
# 确保至少保留1条
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
dd.context.msgs = kept_msgs
|
||||
deleted_count = original_count - len(kept_msgs)
|
||||
total_deleted_msgs += deleted_count
|
||||
|
||||
# 输出删除详情
|
||||
if deleted_details:
|
||||
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
|
||||
for detail in deleted_details:
|
||||
self._log(f" {detail}")
|
||||
|
||||
# ========== 问答对统计(已注释) ==========
|
||||
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
|
||||
# ========================================
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
result.append(dd)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
|
||||
# 保存日志
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
@@ -448,6 +731,7 @@ class SemanticPruner:
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
"拜拜", "再见", "88", "拜", "回见",
|
||||
# 口头禅
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
|
||||
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
|
||||
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
|
||||
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
|
||||
# 网络用语
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
@@ -1932,17 +1932,17 @@ def preprocess_data(
|
||||
Returns:
|
||||
经过清洗转换后的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 数据预处理 ===")
|
||||
logger.debug("=== 数据预处理 ===")
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
return cleaned_data
|
||||
except Exception as e:
|
||||
print(f"数据预处理过程中出现错误: {e}")
|
||||
logger.error(f"数据预处理过程中出现错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
if not data:
|
||||
raise ValueError("预处理数据为空,无法进行分块")
|
||||
|
||||
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: Optional[str] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
skip_cleaning: bool = True,
|
||||
pruning_config: Optional[Dict] = None,
|
||||
) -> List[DialogData]:
|
||||
"""包含数据预处理步骤的完整分块流程
|
||||
|
||||
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: 输入数据路径
|
||||
llm_client: LLM 客户端
|
||||
skip_cleaning: 是否跳过数据清洗步骤(默认False)
|
||||
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
|
||||
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
logger.debug("=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
input_data_path = os.path.join(
|
||||
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
pruner = SemanticPruner(llm_client=llm_client)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
|
||||
# 构建剪枝配置
|
||||
if pruning_config:
|
||||
# 使用传入的配置
|
||||
config = PruningConfig(**pruning_config)
|
||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
else:
|
||||
# 使用默认配置(关闭剪枝)
|
||||
config = None
|
||||
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
|
||||
|
||||
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||
|
||||
# 记录单对话场景下剪枝前的消息数量
|
||||
single_dialog_original_msgs = None
|
||||
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
|
||||
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
|
||||
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
|
||||
print(
|
||||
logger.debug(
|
||||
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs},"
|
||||
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
|
||||
)
|
||||
else:
|
||||
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
|
||||
# 保存剪枝后的数据
|
||||
try:
|
||||
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
dp = DataPreprocessor(output_file_path=pruned_output_path)
|
||||
dp.save_data(preprocessed_data, output_path=pruned_output_path)
|
||||
except Exception as se:
|
||||
print(f"保存剪枝结果失败:{se}")
|
||||
logger.error(f"保存剪枝结果失败:{se}")
|
||||
except Exception as e:
|
||||
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
|
||||
# 步骤3: 对话分块
|
||||
return await get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class ChunkerStrategy(Enum):
|
||||
"""Supported chunking strategies."""
|
||||
RECURSIVE = "RecursiveChunker"
|
||||
SEMANTIC = "SemanticChunker"
|
||||
LATE = "LateChunker"
|
||||
NEURAL = "NeuralChunker"
|
||||
LLM = "LLMChunker"
|
||||
|
||||
@classmethod
|
||||
def get_valid_strategies(cls) -> List[str]:
|
||||
"""Get list of valid strategy names."""
|
||||
return [strategy.value for strategy in cls]
|
||||
|
||||
|
||||
class DialogueChunker:
|
||||
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
|
||||
|
||||
@@ -17,23 +33,51 @@ class DialogueChunker:
|
||||
of different chunking strategies to dialogue data.
|
||||
"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
|
||||
"""Initialize the DialogueChunker with a specific chunking strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
|
||||
llm_client: LLM client instance (required for LLMChunker strategy)
|
||||
|
||||
Raises:
|
||||
ValueError: If chunker_strategy is invalid or required parameters are missing
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# Validate strategy
|
||||
valid_strategies = ChunkerStrategy.get_valid_strategies()
|
||||
if chunker_strategy not in valid_strategies:
|
||||
raise ValueError(
|
||||
f"Invalid chunker_strategy: '{chunker_strategy}'. "
|
||||
f"Must be one of {valid_strategies}"
|
||||
)
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
self.chunker_strategy = chunker_strategy
|
||||
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
|
||||
|
||||
try:
|
||||
# Load and validate configuration
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
if not chunker_config_dict:
|
||||
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
|
||||
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
|
||||
# Initialize chunker client
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("llm_client is required for LLMChunker strategy")
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
|
||||
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
|
||||
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
|
||||
"""Process a dialogue by generating chunks and adding them to the DialogData object.
|
||||
|
||||
Args:
|
||||
@@ -43,54 +87,125 @@ class DialogueChunker:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
ValueError: If dialogue is invalid or chunking fails
|
||||
Exception: If chunking process encounters an error
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# Validate input
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
|
||||
f"Context: {dialogue.context is not None}, "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
|
||||
f"using strategy: {self.chunker_strategy}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate chunks
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
return chunks
|
||||
# Validate results
|
||||
if not chunks or len(chunks) == 0:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs)}, "
|
||||
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
|
||||
logger.info(
|
||||
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
|
||||
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
except ValueError:
|
||||
# Re-raise validation errors
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def save_chunking_results(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
dialogue: DialogData,
|
||||
output_path: Optional[str] = None,
|
||||
preview_length: int = 100
|
||||
) -> str:
|
||||
"""Save the chunking results to a file and return the output path.
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output
|
||||
chunks: List of Chunk objects to save
|
||||
dialogue: The DialogData object that was processed
|
||||
output_path: Optional path to save the output (defaults to current directory)
|
||||
preview_length: Maximum length of content preview (default: 100)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
|
||||
Raises:
|
||||
ValueError: If chunks or dialogue is invalid
|
||||
IOError: If file writing fails
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
# Validate input
|
||||
if not chunks:
|
||||
raise ValueError("chunks list cannot be empty")
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
# Generate default output path if not provided
|
||||
if not output_path:
|
||||
output_dir = Path(__file__).parent.parent.parent
|
||||
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
|
||||
logger.info(f"Saving chunking results to: {output_path}")
|
||||
|
||||
try:
|
||||
# Prepare output content
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
|
||||
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
|
||||
f"Generated {len(chunks)} chunks:",
|
||||
""
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
content_preview = chunk.content[:preview_length] if chunk.content else ""
|
||||
if len(chunk.content) > preview_length:
|
||||
content_preview += "..."
|
||||
|
||||
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {content_preview}")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
output_lines.append("")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
# Write to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
|
||||
logger.info(f"Chunking results saved to: {output_path}")
|
||||
return output_path
|
||||
logger.info(f"Successfully saved chunking results to: {output_path}")
|
||||
return output_path
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ class MultiOntologyParser:
|
||||
|
||||
Example:
|
||||
>>> parser = MultiOntologyParser([
|
||||
... "General_purpose_entity.ttl",
|
||||
... "app/core/memory/ontology_services/General_purpose_entity.ttl",
|
||||
... "domain_specific.owl"
|
||||
... ])
|
||||
>>> registry = parser.parse_all()
|
||||
|
||||
@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str,
|
||||
language: str = "zh"
|
||||
language: str = "zh",
|
||||
user_display_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
# 如果没有提供 user_display_name,使用默认值
|
||||
if user_display_name is None:
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
|
||||
template = prompt_env.get_template("user_summary.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements,
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements),
|
||||
'language': language
|
||||
'language': language,
|
||||
'user_display_name': user_display_name
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the interest filter prompt using the interest_filter.jinja2 template.
|
||||
|
||||
Args:
|
||||
tag_list: Comma-separated string of raw tags to filter
|
||||
language: Output language ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("interest_filter.jinja2")
|
||||
rendered_prompt = template.render(tag_list=tag_list, language=language)
|
||||
log_prompt_rendering('interest filter', rendered_prompt)
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, dialog_text
|
||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -16,7 +16,8 @@
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{% set scene_instructions = {
|
||||
{# ── 内置场景的固定说明 ── #}
|
||||
{% set builtin_scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
@@ -31,16 +32,40 @@
|
||||
}
|
||||
} %}
|
||||
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in scene_instructions %}
|
||||
{% set scene_key = 'education' %}
|
||||
{# ── 确定最终使用的场景说明 ── #}
|
||||
{% if is_builtin_scene %}
|
||||
{# 内置专门场景:使用固定说明 #}
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% else %}
|
||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{# 无本体类型时退化为通用说明 #}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
场景说明:{{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
对话全文:
|
||||
"""
|
||||
@@ -60,6 +85,9 @@
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
Full Dialogue:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{% if language == "zh" %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
|
||||
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
|
||||
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
|
||||
- '川味可视化', '川菜' → '烹饪'
|
||||
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
|
||||
- '吉他', '指弹', '琴谱' → '吉他'
|
||||
- '跑步', '5公里', '跑鞋' → '跑步'
|
||||
- '瑜伽垫', '瑜伽课' → '瑜伽'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
|
||||
- If multiple tags all point to '攀岩', output '攀岩' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
|
||||
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in Chinese.
|
||||
|
||||
**Example**:
|
||||
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
|
||||
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
|
||||
{% else %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
|
||||
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
|
||||
- 'morning meditation streak', 'mind-body peak' → 'meditation'
|
||||
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
|
||||
- 'open source project', 'data visualization tool', 'Python' → 'programming'
|
||||
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
|
||||
- 'running', '5km', 'running shoes' → 'running'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
|
||||
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in English.
|
||||
|
||||
**Example**:
|
||||
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
|
||||
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
|
||||
{% endif %}
|
||||
@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
- User ID: {{ user_id }}
|
||||
{% if user_display_name %}
|
||||
- User Display Name: {{ user_display_name }}
|
||||
{% endif %}
|
||||
{% if entities %}
|
||||
- Core Entities & Frequency: {{ entities }}
|
||||
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
|
||||
{% if language == "zh" %}
|
||||
**【严格人称规定】**
|
||||
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
|
||||
- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户
|
||||
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
|
||||
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA)
|
||||
{% else %}
|
||||
**【STRICT PRONOUN RULES】**
|
||||
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
|
||||
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
|
||||
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
|
||||
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
|
||||
{% endif %}
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
{% if language == "zh" %}
|
||||
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: 张三
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
|
||||
|
||||
Example Output:
|
||||
【基本介绍】
|
||||
我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
|
||||
【性格特点】
|
||||
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
|
||||
@@ -121,13 +135,13 @@ Example Output:
|
||||
"让每一个产品决策都充满温度。"
|
||||
{% else %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: John
|
||||
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
|
||||
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
|
||||
|
||||
Example Output:
|
||||
【Basic Introduction】
|
||||
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
|
||||
【Personality Traits】
|
||||
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.
|
||||
|
||||
@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedBearModelConfig(BaseModel):
|
||||
"""模型配置基类"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||
concurrency: int = 5 # 并发限流
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
import httpx
|
||||
if not config.base_url:
|
||||
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
timeout_config = httpx.Timeout(
|
||||
timeout=config.timeout,
|
||||
connect=60.0,
|
||||
read=config.timeout,
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
@@ -65,7 +89,7 @@ class RedBearModelFactory:
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
@@ -82,7 +106,7 @@ class RedBearModelFactory:
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
from botocore.config import Config as BotoConfig
|
||||
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||
|
||||
|
||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||
# Configure with increased connection pool
|
||||
@@ -90,16 +114,16 @@ class RedBearModelFactory:
|
||||
max_pool_connections=max_pool_connections,
|
||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||
)
|
||||
|
||||
|
||||
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||
model_id = normalize_bedrock_model_id(config.model_name)
|
||||
|
||||
|
||||
params = {
|
||||
"model_id": model_id,
|
||||
"config": boto_config,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
|
||||
# 解析 API key (格式: access_key_id:secret_access_key)
|
||||
if config.api_key and ":" in config.api_key:
|
||||
access_key_id, secret_access_key = config.api_key.split(":", 1)
|
||||
@@ -107,45 +131,52 @@ class RedBearModelFactory:
|
||||
params["aws_secret_access_key"] = secret_access_key
|
||||
elif config.api_key:
|
||||
params["aws_access_key_id"] = config.api_key
|
||||
|
||||
|
||||
# 设置 region
|
||||
if config.base_url:
|
||||
params["region_name"] = config.base_url
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
||||
|
||||
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
return OllamaEmbeddings
|
||||
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_rerank_class(provider: str):
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -6,6 +6,8 @@ models:
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: bedrock
|
||||
@@ -15,6 +17,9 @@ models:
|
||||
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -28,6 +33,9 @@ models:
|
||||
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -42,6 +50,8 @@ models:
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -54,6 +64,9 @@ models:
|
||||
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -67,6 +80,8 @@ models:
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -78,6 +93,8 @@ models:
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -89,6 +106,8 @@ models:
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -101,6 +120,8 @@ models:
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -113,6 +134,8 @@ models:
|
||||
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
@@ -122,6 +145,8 @@ models:
|
||||
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
@@ -131,6 +156,9 @@ models:
|
||||
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
- vision
|
||||
@@ -141,6 +169,8 @@ models:
|
||||
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -150,6 +180,8 @@ models:
|
||||
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -159,6 +191,8 @@ models:
|
||||
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -168,6 +202,8 @@ models:
|
||||
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
logo: bedrock
|
||||
@@ -6,6 +6,8 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -16,6 +18,8 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -26,6 +30,8 @@ models:
|
||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -36,6 +42,8 @@ models:
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -46,6 +54,8 @@ models:
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -56,6 +66,8 @@ models:
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -66,6 +78,8 @@ models:
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -76,6 +90,8 @@ models:
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -88,6 +104,8 @@ models:
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -100,6 +118,9 @@ models:
|
||||
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
@@ -112,6 +133,9 @@ models:
|
||||
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
@@ -124,6 +148,8 @@ models:
|
||||
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -135,6 +161,8 @@ models:
|
||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -147,6 +175,8 @@ models:
|
||||
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -159,6 +189,8 @@ models:
|
||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -171,6 +203,8 @@ models:
|
||||
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
@@ -182,6 +216,8 @@ models:
|
||||
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
@@ -193,6 +229,8 @@ models:
|
||||
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -205,6 +243,8 @@ models:
|
||||
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -217,6 +257,8 @@ models:
|
||||
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -229,6 +271,8 @@ models:
|
||||
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -241,6 +285,8 @@ models:
|
||||
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -253,6 +299,8 @@ models:
|
||||
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -265,6 +313,8 @@ models:
|
||||
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -277,6 +327,8 @@ models:
|
||||
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -289,6 +341,10 @@ models:
|
||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -302,6 +358,10 @@ models:
|
||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -315,6 +375,10 @@ models:
|
||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -328,6 +392,10 @@ models:
|
||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -341,6 +409,10 @@ models:
|
||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -354,6 +426,10 @@ models:
|
||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -367,6 +443,8 @@ models:
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -379,6 +457,8 @@ models:
|
||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -391,6 +471,8 @@ models:
|
||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -403,6 +485,8 @@ models:
|
||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -415,6 +499,8 @@ models:
|
||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -427,6 +513,8 @@ models:
|
||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -439,6 +527,8 @@ models:
|
||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -451,6 +541,8 @@ models:
|
||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -463,6 +555,8 @@ models:
|
||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -475,6 +569,8 @@ models:
|
||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -487,6 +583,8 @@ models:
|
||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -498,6 +596,8 @@ models:
|
||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -509,6 +609,8 @@ models:
|
||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -520,6 +622,8 @@ models:
|
||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -531,6 +635,8 @@ models:
|
||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -544,6 +650,8 @@ models:
|
||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -557,6 +665,8 @@ models:
|
||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -569,6 +679,8 @@ models:
|
||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -582,6 +694,8 @@ models:
|
||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -594,6 +708,8 @@ models:
|
||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -606,6 +722,11 @@ models:
|
||||
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -620,6 +741,10 @@ models:
|
||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -635,6 +760,10 @@ models:
|
||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -650,6 +779,10 @@ models:
|
||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -665,6 +798,10 @@ models:
|
||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -680,6 +817,10 @@ models:
|
||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -695,6 +836,10 @@ models:
|
||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -708,6 +853,10 @@ models:
|
||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -721,6 +870,8 @@ models:
|
||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -732,6 +883,8 @@ models:
|
||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -743,6 +896,8 @@ models:
|
||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -754,6 +909,8 @@ models:
|
||||
description: gte-rerank-v2重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
@@ -763,6 +920,8 @@ models:
|
||||
description: gte-rerank重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
@@ -772,6 +931,9 @@ models:
|
||||
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 多模态模型
|
||||
@@ -783,6 +945,8 @@ models:
|
||||
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -793,6 +957,8 @@ models:
|
||||
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -803,6 +969,8 @@ models:
|
||||
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -813,7 +981,9 @@ models:
|
||||
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
logo: dashscope
|
||||
@@ -6,7 +6,7 @@ from typing import Callable
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
from app.models.models_model import ModelBase, ModelProvider, ModelConfig
|
||||
|
||||
|
||||
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
for model_data in models:
|
||||
config_sync_fields = {
|
||||
"logo": None,
|
||||
"capability": None,
|
||||
"is_omni": None,
|
||||
"name": None,
|
||||
"provider": None,
|
||||
"type": None,
|
||||
"description": None
|
||||
}
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
existing = db.query(ModelBase).filter(
|
||||
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
# 更新现有模型配置
|
||||
for key, value in model_data.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
# 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey
|
||||
sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
|
||||
if sync_fields:
|
||||
# 批量更新 ModelConfig
|
||||
update_kwargs = {k: model_data[k] for k in sync_fields}
|
||||
db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
|
||||
update_kwargs,
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
# 更新 ModelApiKey 的 capability 和 is_omni
|
||||
if 'capability' in model_data or 'is_omni' in model_data:
|
||||
from app.models.models_model import ModelApiKey, model_config_api_key_association
|
||||
api_key_update = {}
|
||||
if 'capability' in model_data:
|
||||
api_key_update['capability'] = model_data['capability']
|
||||
if 'is_omni' in model_data:
|
||||
api_key_update['is_omni'] = model_data['is_omni']
|
||||
|
||||
if api_key_update:
|
||||
# 查找所有关联的 API Key
|
||||
api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
|
||||
ModelConfig,
|
||||
ModelConfig.id == model_config_api_key_association.c.model_config_id
|
||||
).filter(ModelConfig.model_id == existing.id).distinct().all()
|
||||
|
||||
if api_key_ids:
|
||||
api_key_ids = [aid[0] for aid in api_key_ids]
|
||||
db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
|
||||
api_key_update,
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
|
||||
@@ -6,12 +6,19 @@ models:
|
||||
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-0125
|
||||
type: llm
|
||||
@@ -19,6 +26,8 @@ models:
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -31,6 +40,8 @@ models:
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -43,6 +54,8 @@ models:
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -55,6 +68,8 @@ models:
|
||||
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: openai
|
||||
@@ -64,6 +79,8 @@ models:
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -76,6 +93,8 @@ models:
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -88,6 +107,8 @@ models:
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -100,6 +121,9 @@ models:
|
||||
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -113,6 +137,8 @@ models:
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -125,6 +151,9 @@ models:
|
||||
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -138,6 +167,8 @@ models:
|
||||
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -148,6 +179,9 @@ models:
|
||||
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -162,6 +196,9 @@ models:
|
||||
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -176,6 +213,8 @@ models:
|
||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -189,6 +228,8 @@ models:
|
||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -202,6 +243,9 @@ models:
|
||||
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -215,6 +259,9 @@ models:
|
||||
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -228,6 +275,9 @@ models:
|
||||
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -242,6 +292,9 @@ models:
|
||||
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -256,6 +309,9 @@ models:
|
||||
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -270,6 +326,8 @@ models:
|
||||
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
@@ -279,6 +337,8 @@ models:
|
||||
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
@@ -288,6 +348,8 @@ models:
|
||||
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
logo: openai
|
||||
8
api/app/core/workflow/adapters/__init__.py
Normal file
8
api/app/core/workflow/adapters/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 15:54
|
||||
from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
|
||||
__all__ = ["DifyAdapter", "MemoryBearAdapter"]
|
||||
90
api/app/core/workflow/adapters/base_adapter.py
Normal file
90
api/app/core/workflow/adapters/base_adapter.py
Normal file
@@ -0,0 +1,90 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 15:58
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition
|
||||
from app.schemas.workflow_schema import (
|
||||
EdgeDefinition,
|
||||
NodeDefinition,
|
||||
VariableDefinition,
|
||||
ExecutionConfig,
|
||||
TriggerConfig
|
||||
)
|
||||
|
||||
|
||||
class PlatformType(StrEnum):
|
||||
MEMORY_BEAR = "memory_bear"
|
||||
DIFY = "dify"
|
||||
COZE = "coze"
|
||||
|
||||
|
||||
class PlatformMetadata(BaseModel):
|
||||
platform_name: str
|
||||
version: str
|
||||
support_node_types: list[str]
|
||||
|
||||
|
||||
class WorkflowParserResult(BaseModel):
|
||||
success: bool
|
||||
platform: PlatformMetadata
|
||||
execution_config: ExecutionConfig
|
||||
origin_config: dict[str, Any]
|
||||
trigger: TriggerConfig | None
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class WorkflowImportResult(BaseModel):
|
||||
success: bool
|
||||
temp_id: str | None = Field(..., description="cache id")
|
||||
workflow_id: str | None = Field(..., description="workflow id")
|
||||
edges: list[EdgeDefinition] = Field(default_factory=list)
|
||||
nodes: list[NodeDefinition] = Field(default_factory=list)
|
||||
variables: list[VariableDefinition] = Field(default_factory=list)
|
||||
warnings: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
errors: list[ExceptionDefineition] = Field(default_factory=list)
|
||||
|
||||
|
||||
class BasePlatformAdapter(ABC):
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
self.config = config
|
||||
self.nodes: list[NodeDefinition] = []
|
||||
self.edges: list[EdgeDefinition] = []
|
||||
self.conv_variables: list[VariableDefinition] = []
|
||||
|
||||
self.errors = []
|
||||
self.warnings = []
|
||||
|
||||
self.branch_node_cache = defaultdict(list)
|
||||
self.error_branch_node_cache = []
|
||||
|
||||
self.node_output_map = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
"""get platform metadata"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_config(self) -> bool:
|
||||
"""platform configuration validate"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
"""parse platform configuration to local config"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def map_node_type(self, platform_node_type: str) -> str:
|
||||
pass
|
||||
75
api/app/core/workflow/adapters/base_converter.py
Normal file
75
api/app/core/workflow/adapters/base_converter.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 14:32
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType
|
||||
|
||||
|
||||
class BaseConverter(ABC):
|
||||
@staticmethod
|
||||
def _convert_string(var):
|
||||
try:
|
||||
return str(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.STRING)
|
||||
|
||||
@staticmethod
|
||||
def _convert_boolean(var):
|
||||
try:
|
||||
return bool(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.BOOLEAN)
|
||||
|
||||
@staticmethod
|
||||
def _convert_number(var):
|
||||
try:
|
||||
return float(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.NUMBER)
|
||||
|
||||
@staticmethod
|
||||
def _convert_object(var):
|
||||
try:
|
||||
return dict(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.OBJECT)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _convert_file(var):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_string(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_STRING)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_number(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_NUMBER)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_boolean(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN)
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_object(var):
|
||||
try:
|
||||
return list(var)
|
||||
except:
|
||||
return DEFAULT_VALUE(VariableType.ARRAY_OBJECT)
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _convert_array_file(var):
|
||||
pass
|
||||
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
4
api/app/core/workflow/adapters/dify/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 18:20
|
||||
734
api/app/core/workflow/adapters/dify/converter.py
Normal file
734
api/app/core/workflow/adapters/dify/converter.py
Normal file
@@ -0,0 +1,734 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 18:21
|
||||
import base64
|
||||
import re
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||
CycleVariable
|
||||
from app.core.workflow.nodes.end import EndNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||
HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
|
||||
class DifyConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
branch_node_cache: dict
|
||||
error_branch_node_cache: list
|
||||
node_output_map: dict
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
"start": self.convert_start_node_config,
|
||||
"llm": self.convert_llm_node_config,
|
||||
"answer": self.convert_end_node_config,
|
||||
"if-else": self.convert_if_else_node_config,
|
||||
"loop": self.convert_loop_node_config,
|
||||
"iteration": self.convert_iteration_node_config,
|
||||
"assigner": self.convert_assigner_node_config,
|
||||
"code": self.convert_code_node_config,
|
||||
"http-request": self.convert_http_node_config,
|
||||
"template-transform": self.convert_jinja_render_node_config,
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
||||
"tool": self.convert_tool_node_config,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
|
||||
return func
|
||||
|
||||
def config_validate(
|
||||
self,
|
||||
node_id: str,
|
||||
node_name: str,
|
||||
config: type[BaseNodeConfig],
|
||||
value: dict
|
||||
):
|
||||
try:
|
||||
return config.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_variable(expression) -> bool:
|
||||
return bool(re.match(r"\{\{#(.*?)#}}", expression))
|
||||
|
||||
def process_var_selector(self, var_selector):
|
||||
if not var_selector:
|
||||
return ""
|
||||
selector = var_selector.split('.')
|
||||
if len(selector) not in [2, 3] and var_selector != "context":
|
||||
raise Exception(f"invalid variable selector: {var_selector}")
|
||||
if len(selector) == 3:
|
||||
selector = selector[1:]
|
||||
if selector[0] == "conversation":
|
||||
selector[0] = "conv"
|
||||
var_selector = ".".join(selector)
|
||||
mapping = {
|
||||
"sys.query": "sys.message"
|
||||
} | self.node_output_map
|
||||
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
|
||||
def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
|
||||
if not self.process_var_selector(".".join(variable_selector)):
|
||||
return None
|
||||
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
|
||||
|
||||
def trans_variable_format(self, content):
|
||||
pattern = re.compile(r"\{\{#(.*?)#}}")
|
||||
|
||||
def replacer(match: re.Match) -> str:
|
||||
raw_name = match.group(1)
|
||||
new_name = self.process_var_selector(raw_name)
|
||||
return f"{{{{{new_name}}}}}"
|
||||
|
||||
return pattern.sub(replacer, content)
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def variable_type_map(source_type) -> VariableType | None:
|
||||
type_map = {
|
||||
"file": VariableType.FILE,
|
||||
"paragraph": VariableType.STRING,
|
||||
"text-input": VariableType.STRING,
|
||||
"number": VariableType.NUMBER,
|
||||
"checkbox": VariableType.BOOLEAN,
|
||||
"file-list": VariableType.ARRAY_FILE,
|
||||
"select": VariableType.STRING,
|
||||
"integer": VariableType.NUMBER,
|
||||
"float": VariableType.NUMBER,
|
||||
}
|
||||
var_type = type_map.get(source_type, source_type)
|
||||
return var_type
|
||||
|
||||
def convert_variable_type(self, target_type: VariableType, origin_value: Any):
|
||||
if not origin_value:
|
||||
return DEFAULT_VALUE(target_type)
|
||||
try:
|
||||
match target_type:
|
||||
case VariableType.STRING:
|
||||
return self._convert_string(origin_value)
|
||||
case VariableType.NUMBER:
|
||||
return self._convert_number(origin_value)
|
||||
case VariableType.BOOLEAN:
|
||||
return self._convert_boolean(origin_value)
|
||||
case VariableType.FILE:
|
||||
return self._convert_file(origin_value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return self._convert_array_file(origin_value)
|
||||
case _:
|
||||
return origin_value
|
||||
except:
|
||||
raise Exception(f"convert variable failed: {target_type}")
|
||||
|
||||
@staticmethod
|
||||
def convert_compare_operator(operator):
|
||||
operator_map = {
|
||||
"is": ComparisonOperator.EQ,
|
||||
"is not": ComparisonOperator.NE,
|
||||
"=": ComparisonOperator.EQ,
|
||||
"≠": ComparisonOperator.NE,
|
||||
">": ComparisonOperator.GT,
|
||||
"<": ComparisonOperator.LT,
|
||||
"≥": ComparisonOperator.GE,
|
||||
"≤": ComparisonOperator.LE,
|
||||
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||
"start with": ComparisonOperator.START_WITH,
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
"not contains": ComparisonOperator.NOT_CONTAINS,
|
||||
"exists": ComparisonOperator.NOT_EMPTY,
|
||||
"not exists": ComparisonOperator.EMPTY
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@staticmethod
|
||||
def convert_assignment_operator(operator):
|
||||
operator_map = {
|
||||
"+=": AssignmentOperator.ADD,
|
||||
"-=": AssignmentOperator.SUBTRACT,
|
||||
"*=": AssignmentOperator.MULTIPLY,
|
||||
"/=": AssignmentOperator.DIVIDE,
|
||||
"over-write": AssignmentOperator.COVER,
|
||||
"remove-last": AssignmentOperator.REMOVE_LAST,
|
||||
"remove-first": AssignmentOperator.REMOVE_FIRST,
|
||||
"set": AssignmentOperator.ASSIGN,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_auth_type(auth_type):
|
||||
auth_type_map = {
|
||||
"no-auth": HttpAuthType.NONE,
|
||||
"bearer": HttpAuthType.BEARER,
|
||||
"basic": HttpAuthType.BASIC,
|
||||
"custom": HttpAuthType.CUSTOM,
|
||||
}
|
||||
return auth_type_map.get(auth_type, auth_type)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_content_type(content_type):
|
||||
content_type_map = {
|
||||
"none": HttpContentType.NONE,
|
||||
"form-data": HttpContentType.FROM_DATA,
|
||||
"x-www-form-urlencoded": HttpContentType.WWW_FORM,
|
||||
"json": HttpContentType.JSON,
|
||||
"raw-text": HttpContentType.RAW,
|
||||
"binary": HttpContentType.BINARY,
|
||||
}
|
||||
return content_type_map.get(content_type, content_type)
|
||||
|
||||
@staticmethod
|
||||
def convert_http_error_handle_type(handle_type):
|
||||
handle_type_map = {
|
||||
"none": HttpErrorHandle.NONE,
|
||||
"fail-branch": HttpErrorHandle.BRANCH,
|
||||
"default-value": HttpErrorHandle.DEFAULT,
|
||||
}
|
||||
return handle_type_map.get(handle_type, handle_type)
|
||||
|
||||
def convert_start_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
start_vars = []
|
||||
for var in node_data["variables"]:
|
||||
var_type = self.variable_type_map(var["type"])
|
||||
if not var_type:
|
||||
self.errors.append(
|
||||
UnsupportVariableType(
|
||||
scope=node["id"],
|
||||
name=var["variable"],
|
||||
var_type=var["type"],
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"]
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
if var_type in ["file", "array[file]"]:
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"Unsupported Variable type for start node: {var_type}"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
var_def = VariableDefinition(
|
||||
name=var["variable"],
|
||||
type=var_type,
|
||||
required=var["required"],
|
||||
default=self.convert_variable_type(
|
||||
var_type, var.get("default")
|
||||
),
|
||||
description=var["label"],
|
||||
max_length=var.get("max_length", 50),
|
||||
)
|
||||
start_vars.append(var_def)
|
||||
result = StartNodeConfig.model_construct(
|
||||
variables=start_vars
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
categories = []
|
||||
for category in node_data["classes"]:
|
||||
self.branch_node_cache[node["id"]].append(category["id"])
|
||||
categories.append(
|
||||
ClassifierConfig.model_construct(
|
||||
class_name=category["name"],
|
||||
)
|
||||
)
|
||||
|
||||
result = QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||
categories=categories,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_llm_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
|
||||
memory = MemoryWindowSetting(
|
||||
enable=bool(node_data.get("memory")),
|
||||
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
|
||||
window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20))
|
||||
)
|
||||
messages = []
|
||||
for message in node_data["prompt_template"]:
|
||||
messages.append(
|
||||
MessageConfig(
|
||||
role=message["role"],
|
||||
content=self.trans_variable_format(message["text"])
|
||||
)
|
||||
)
|
||||
if memory.enable:
|
||||
messages.append(
|
||||
MessageConfig(
|
||||
role="user",
|
||||
content=self.trans_variable_format(
|
||||
node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}"
|
||||
)
|
||||
)
|
||||
)
|
||||
vision = node_data["vision"]["enabled"]
|
||||
vision_input = self._process_list_variable_litearl(
|
||||
node_data["vision"]["configs"]["variable_selector"]
|
||||
) if vision else None
|
||||
result = LLMNodeConfig.model_construct(
|
||||
model_id=None,
|
||||
context=context,
|
||||
memory=memory,
|
||||
vision=vision,
|
||||
vision_input=vision_input,
|
||||
messages=messages
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_end_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
result = EndNodeConfig.model_construct(
|
||||
output=self.trans_variable_format(node_data.get("answer", "")),
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_if_else_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
cases = []
|
||||
for case in node_data["cases"]:
|
||||
case_id = case.get("id") or case.get("case_id")
|
||||
logical_operator = case["logical_operator"]
|
||||
conditions = []
|
||||
for condition in case["conditions"]:
|
||||
right_value = condition["value"]
|
||||
condition_detail = ConditionDetail(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}",
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
self.variable_type_map(condition["varType"]),
|
||||
condition["value"]
|
||||
),
|
||||
input_type=ValueInputType.VARIABLE
|
||||
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||
)
|
||||
conditions.append(condition_detail)
|
||||
cases.append(
|
||||
ConditionBranchConfig(
|
||||
logical_operator=logical_operator,
|
||||
expressions=conditions
|
||||
)
|
||||
)
|
||||
self.branch_node_cache[node["id"]].append(case_id)
|
||||
result = IfElseNodeConfig.model_construct(
|
||||
cases=cases
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_loop_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
logical_operator = node_data["logical_operator"]
|
||||
conditions = []
|
||||
for condition in node_data["break_conditions"]:
|
||||
right_value = condition["value"]
|
||||
conditions.append(
|
||||
LoopConditionDetail.model_construct(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||
right=self.trans_variable_format(
|
||||
right_value
|
||||
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
|
||||
self.variable_type_map(condition["varType"]),
|
||||
condition["value"]
|
||||
),
|
||||
input_type=ValueInputType.VARIABLE
|
||||
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||
)
|
||||
)
|
||||
condition_config = ConditionsConfig.model_construct(
|
||||
logical_operator=logical_operator,
|
||||
expressions=conditions
|
||||
)
|
||||
loop_variables = []
|
||||
for variable in node_data["loop_variables"]:
|
||||
right_input_type = variable["value_type"]
|
||||
right_value_type = self.variable_type_map(variable["var_type"])
|
||||
if right_input_type == ValueInputType.VARIABLE:
|
||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
||||
else:
|
||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||
loop_variables.append(
|
||||
CycleVariable(
|
||||
name=variable["label"],
|
||||
type=right_value_type,
|
||||
value=right_value,
|
||||
input_type=right_input_type
|
||||
)
|
||||
)
|
||||
result = LoopNodeConfig.model_construct(
|
||||
condition=condition_config,
|
||||
cycle_vars=loop_variables,
|
||||
max_loop=node_data.get("loop_count", 10)
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
assignments = []
|
||||
for assignment in node_data["items"]:
|
||||
if assignment.get("operation") is None or assignment.get("value") is None:
|
||||
continue
|
||||
assignments.append(
|
||||
AssignmentItem(
|
||||
variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
|
||||
value=self._process_list_variable_litearl(
|
||||
assignment["value"]
|
||||
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
|
||||
operation=self.convert_assignment_operator(assignment["operation"])
|
||||
)
|
||||
)
|
||||
result = AssignerNodeConfig.model_construct(
|
||||
assignments=assignments
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_code_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
input_variables = []
|
||||
for input_variable in node_data["variables"]:
|
||||
input_variables.append(
|
||||
InputVariable.model_construct(
|
||||
name=input_variable["variable"],
|
||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||
)
|
||||
)
|
||||
|
||||
output_variables = []
|
||||
for output_variable in node_data["outputs"]:
|
||||
output_variables.append(
|
||||
OutputVariable.model_construct(
|
||||
name=output_variable,
|
||||
type=node_data["outputs"][output_variable]["type"],
|
||||
)
|
||||
)
|
||||
|
||||
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
|
||||
|
||||
result = CodeNodeConfig.model_construct(
|
||||
input_variables=input_variables,
|
||||
language=node_data["code_language"],
|
||||
output_variables=output_variables,
|
||||
code=code
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_http_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
if node_data["authorization"]["type"] != 'no-auth':
|
||||
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
|
||||
auth_config = HttpAuthConfig.model_construct(
|
||||
auth_type=auth_type,
|
||||
header=node_data["authorization"]["config"].get("header"),
|
||||
api_key=node_data["authorization"]["config"].get("api_key"),
|
||||
)
|
||||
else:
|
||||
auth_config = HttpAuthConfig()
|
||||
|
||||
content_type = self.convert_http_content_type(node_data["body"]["type"])
|
||||
if content_type == HttpContentType.FROM_DATA:
|
||||
body_content = []
|
||||
for content in node_data["body"]["data"]:
|
||||
body_content.append(
|
||||
HttpFormData(
|
||||
key=self.trans_variable_format(content["key"]),
|
||||
type=content["type"],
|
||||
value=self.trans_variable_format(content["value"]),
|
||||
)
|
||||
)
|
||||
elif content_type == HttpContentType.WWW_FORM:
|
||||
body_content = {}
|
||||
for content in node_data["body"]["data"]:
|
||||
body_content[
|
||||
self.trans_variable_format(content["key"])
|
||||
] = self.trans_variable_format(content["value"])
|
||||
else:
|
||||
if node_data["body"]["data"]:
|
||||
body_content = (node_data["body"]["data"][0].get("value") or
|
||||
self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
|
||||
else:
|
||||
body_content = ""
|
||||
|
||||
headers = {}
|
||||
for header in node_data.get("headers", "").split("\n"):
|
||||
if not header:
|
||||
continue
|
||||
|
||||
key_value = header.split(":")
|
||||
if len(key_value) == 2:
|
||||
headers[
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
detail=f"Invalid header/param - {header}",
|
||||
))
|
||||
|
||||
params = {}
|
||||
for param in node_data.get("params", "").split("\n"):
|
||||
if not param:
|
||||
continue
|
||||
|
||||
key_value = param.split(":")
|
||||
if len(key_value) == 2:
|
||||
params[
|
||||
self.trans_variable_format(key_value[0])
|
||||
] = self.trans_variable_format(key_value[1])
|
||||
else:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
detail=f"Invalid header/param - {param}",
|
||||
))
|
||||
|
||||
error_handle_type = self.convert_http_error_handle_type(
|
||||
node_data.get("error_strategy", "none")
|
||||
)
|
||||
default_value = None
|
||||
if error_handle_type == HttpErrorHandle.DEFAULT:
|
||||
default_body = ""
|
||||
default_header = {}
|
||||
default_status_code = 0
|
||||
for var in node_data.get("default_value") or []:
|
||||
if var["key"] == "body":
|
||||
default_body = var["value"]
|
||||
elif var["key"] == "header":
|
||||
default_header = var["value"]
|
||||
elif var["key"] == "status_code":
|
||||
default_status_code = var["value"]
|
||||
default_value = HttpErrorDefaultTamplete(
|
||||
body=default_body,
|
||||
headers=default_header,
|
||||
status_code=default_status_code,
|
||||
)
|
||||
|
||||
self.error_branch_node_cache.append(node['id'])
|
||||
result = HttpRequestNodeConfig.model_construct(
|
||||
method=node_data["method"].upper(),
|
||||
url=node_data["url"],
|
||||
auth=auth_config,
|
||||
body=HttpContentTypeConfig.model_construct(
|
||||
content_type=self.convert_http_content_type(node_data["body"]["type"]),
|
||||
data=body_content,
|
||||
),
|
||||
headers=headers,
|
||||
params=params,
|
||||
verify_ssl=node_data.get("ssl_verify", False),
|
||||
timeouts=HttpTimeOutConfig.model_construct(
|
||||
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
|
||||
),
|
||||
retry=HttpRetryConfig.model_construct(
|
||||
enable=node_data["retry_config"]["retry_enabled"],
|
||||
max_attempts=node_data["retry_config"]["max_retries"],
|
||||
retry_interval=node_data["retry_config"]["retry_interval"],
|
||||
),
|
||||
error_handle=HttpErrorHandleConfig.model_construct(
|
||||
method=error_handle_type,
|
||||
default=default_value,
|
||||
)
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_jinja_render_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
mapping = []
|
||||
for variable in node_data["variables"]:
|
||||
mapping.append(VariablesMappingConfig.model_construct(
|
||||
name=variable["variable"],
|
||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||
))
|
||||
result = JinjaRenderNodeConfig.model_construct(
|
||||
template=node_data["template"],
|
||||
mapping=mapping,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||
))
|
||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
UnknowModelWarning(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
model_name=node_data["model"].get("name")
|
||||
)
|
||||
)
|
||||
params = []
|
||||
for param in node_data.get("parameters", []):
|
||||
params.append(
|
||||
ParamsConfig.model_construct(
|
||||
name=param["name"],
|
||||
desc=param["description"],
|
||||
required=param["required"],
|
||||
type=param["type"],
|
||||
)
|
||||
)
|
||||
result = ParameterExtractorNodeConfig.model_construct(
|
||||
text=self._process_list_variable_litearl(node_data["query"]),
|
||||
params=params,
|
||||
prompt=node_data.get("instruction")
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
advanced_settings = node_data.get("advanced_settings", {})
|
||||
group_variables = {}
|
||||
group_type = {}
|
||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||
group_variables = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
group_type["output"] = node_data["output_type"]
|
||||
else:
|
||||
for group in advanced_settings["groups"]:
|
||||
group_variables[group["group_name"]] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in group["variables"]
|
||||
]
|
||||
group_type[group["group_name"]] = group["output_type"]
|
||||
|
||||
result = VariableAggregatorNodeConfig.model_construct(
|
||||
group=advanced_settings.get("group_enabled", False),
|
||||
group_variables=group_variables,
|
||||
group_type=group_type,
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
|
||||
|
||||
return result
|
||||
|
||||
def convert_tool_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the tool node.",
|
||||
))
|
||||
return {}
|
||||
261
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
261
api/app/core/workflow/adapters/dify/dify_adapter.py
Normal file
@@ -0,0 +1,261 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/24 16:05
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
BasePlatformAdapter,
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.core.workflow.adapters.dify.converter import DifyConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import (
|
||||
NodeDefinition,
|
||||
EdgeDefinition,
|
||||
VariableDefinition,
|
||||
TriggerConfig,
|
||||
ExecutionConfig
|
||||
)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
NODE_TYPE_MAPPING = {
|
||||
"start": NodeType.START,
|
||||
"llm": NodeType.LLM,
|
||||
"answer": NodeType.END,
|
||||
"if-else": NodeType.IF_ELSE,
|
||||
"loop-start": NodeType.CYCLE_START,
|
||||
"iteration-start": NodeType.CYCLE_START,
|
||||
"assigner": NodeType.ASSIGNER,
|
||||
"loop": NodeType.LOOP,
|
||||
"iteration": NodeType.ITERATION,
|
||||
"loop-end": NodeType.BREAK,
|
||||
"code": NodeType.CODE,
|
||||
"http-request": NodeType.HTTP_REQUEST,
|
||||
"template-transform": NodeType.JINJARENDER,
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL,
|
||||
"": NodeType.NOTES
|
||||
}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
DifyConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.DIFY,
|
||||
version="0.5.0",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("graph").get("nodes")
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("graph").get("edges")
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "data" not in node:
|
||||
return False
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app",{}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
))
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self._init_node_output_map()
|
||||
for node in self.origin_nodes:
|
||||
node = self._convert_node(node)
|
||||
if node:
|
||||
self.nodes.append(node)
|
||||
nodes_id = [node.id for node in self.nodes]
|
||||
for edge in self.origin_edges:
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
if source not in nodes_id or target not in nodes_id:
|
||||
continue
|
||||
edge = self._convert_edge(edge)
|
||||
if edge:
|
||||
self.edges.append(edge)
|
||||
#
|
||||
for variable in self.config.get("workflow").get("conversation_variables"):
|
||||
con_var = self._convert_variable(variable)
|
||||
if variable:
|
||||
self.conv_variables.append(con_var)
|
||||
#
|
||||
# for variables in config.get("workflow").get("environment_variables"):
|
||||
# variable = self._convert_variable(variables)
|
||||
# conv_variables.append(variable)
|
||||
|
||||
trigger = self._convert_trigger({})
|
||||
execution_config = self._convert_execution({})
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=execution_config,
|
||||
origin_config=self.config,
|
||||
trigger=trigger,
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors
|
||||
)
|
||||
|
||||
def _init_node_output_map(self):
|
||||
for node in self.origin_nodes:
|
||||
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
|
||||
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
|
||||
elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output"
|
||||
|
||||
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||
for node in self.origin_nodes:
|
||||
if node["id"] == node_id:
|
||||
return {
|
||||
"x": node["position"]["x"] + position["x"],
|
||||
"y": node["position"]["y"] + position["y"]
|
||||
}
|
||||
self.errors.append(
|
||||
ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
detail="parent cycle node not found"
|
||||
)
|
||||
)
|
||||
raise Exception("parent cycle node not found")
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_data = node["data"]
|
||||
try:
|
||||
return NodeDefinition(
|
||||
id=node["id"],
|
||||
type=self.map_node_type(node_data["type"]),
|
||||
name=node_data.get("title") or "notes",
|
||||
cycle=node.get("parentId"),
|
||||
description=None,
|
||||
config=self._convert_node_config(node),
|
||||
position={
|
||||
"x": node["position"]["x"],
|
||||
"y": node["position"]["y"]
|
||||
} if node.get("parentId") is None else self._convert_cycle_node_position(
|
||||
node["parentId"],
|
||||
node["position"]
|
||||
),
|
||||
error_handling=None,
|
||||
cache=None
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||
|
||||
def _convert_node_config(self, node: dict):
|
||||
node_data = node["data"]
|
||||
node_type = node_data["type"]
|
||||
try:
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"node type {node_type if node_type else 'notes'} is unsupported",
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"convert node error - {e}",
|
||||
))
|
||||
raise e
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
|
||||
try:
|
||||
|
||||
source = edge["source"]
|
||||
target = edge["target"]
|
||||
label = None
|
||||
if source in self.branch_node_cache:
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "false":
|
||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||
else:
|
||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||
if source in self.error_branch_node_cache:
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "source":
|
||||
label = "SUCCESS"
|
||||
else:
|
||||
label = "ERROR"
|
||||
return EdgeDefinition(
|
||||
id=edge["id"],
|
||||
source=source,
|
||||
target=target,
|
||||
label=label,
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}",
|
||||
))
|
||||
logger.debug(f"convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(
|
||||
name=variable["name"],
|
||||
default=variable["value"],
|
||||
type=self.variable_type_map(variable["value_type"]),
|
||||
description=variable.get("description")
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}",
|
||||
))
|
||||
|
||||
def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None:
|
||||
pass
|
||||
|
||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||
return ExecutionConfig()
|
||||
|
||||
|
||||
75
api/app/core/workflow/adapters/errors.py
Normal file
75
api/app/core/workflow/adapters/errors.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 11:29
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExceptionType(StrEnum):
|
||||
NODE = "node"
|
||||
EDGE = "edge"
|
||||
VARIABLE = "variable"
|
||||
TRIGGER = "trigger"
|
||||
EXECUTION = "execution"
|
||||
CONFIG = "config"
|
||||
PLATFORM = "platform"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class ExceptionDefineition(BaseModel):
|
||||
type: ExceptionType
|
||||
detail: str
|
||||
|
||||
node_id: str | None = None
|
||||
node_name: str | None = None
|
||||
|
||||
scope: str | None = None
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class UnknowModelWarning(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id, node_name, model_name):
|
||||
super().__init__(
|
||||
detail=f"Please specify the model mapping manually for model: {model_name}",
|
||||
node_id=node_id,
|
||||
node_name=node_name
|
||||
)
|
||||
|
||||
|
||||
class UnknowError(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.UNKNOWN
|
||||
|
||||
def __init__(self, detail: str, **kwargs):
|
||||
super().__init__(detail=detail, **kwargs)
|
||||
|
||||
|
||||
class UnsupportPlatform(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.PLATFORM
|
||||
|
||||
def __init__(self, platform: str):
|
||||
super().__init__(detail=f"Unsupport platform {platform}")
|
||||
|
||||
|
||||
class UnsupportVariableType(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.VARIABLE
|
||||
|
||||
def __init__(self, scope, name, var_type: str, **kwargs):
|
||||
super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs)
|
||||
|
||||
|
||||
class InvalidConfiguration(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.CONFIG
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(detail="Invalid workflow configuration format")
|
||||
|
||||
|
||||
class UnsupportNodeType(ExceptionDefineition):
|
||||
type: ExceptionType = ExceptionType.NODE
|
||||
|
||||
def __init__(self, node_id: str, node_type: str):
|
||||
super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")
|
||||
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
4
api/app/core/workflow/adapters/memory_bear/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/26 11:30
|
||||
@@ -0,0 +1,76 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
trigger=None,
|
||||
edges=self.edges,
|
||||
nodes=self.nodes,
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
34
api/app/core/workflow/adapters/registry.py
Normal file
34
api/app/core/workflow/adapters/registry.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/25 14:19
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter
|
||||
from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType
|
||||
|
||||
|
||||
class PlatformAdapterRegistry:
|
||||
_adapters: dict[str, type[BasePlatformAdapter]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, platform: str, adapter: type[BasePlatformAdapter]):
|
||||
cls._adapters[platform] = adapter
|
||||
|
||||
@classmethod
|
||||
def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter:
|
||||
if platform not in cls._adapters:
|
||||
raise ValueError(f"Unsupported platform: {platform}")
|
||||
return cls._adapters.get(platform)(config)
|
||||
|
||||
@classmethod
|
||||
def list_platforms(cls) -> list[str]:
|
||||
return list(cls._adapters.keys())
|
||||
|
||||
@classmethod
|
||||
def is_supported(cls, platform: str) -> bool:
|
||||
return platform in cls._adapters
|
||||
|
||||
|
||||
PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter)
|
||||
PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter)
|
||||
@@ -127,7 +127,7 @@ class EventStreamHandler:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
"content": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -292,6 +292,8 @@ class GraphBuilder:
|
||||
"""
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
if node_type == NodeType.NOTES:
|
||||
continue
|
||||
node_id = node.get("id")
|
||||
cycle_node = node.get("cycle")
|
||||
if cycle_node:
|
||||
@@ -320,7 +322,7 @@ class GraphBuilder:
|
||||
# Used later to determine which branch to take based on the node's output
|
||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# Wrap node's run method to avoid closure issues
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
"content": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
|
||||
instance:
|
||||
The concrete variable object. The actual Python type is
|
||||
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||
NumberVariable, ArrayObject[StringVariable]).
|
||||
NumberVariable, ArrayVariable[StringVariable]).
|
||||
mut:
|
||||
Whether the variable is mutable.
|
||||
"""
|
||||
@@ -152,6 +152,36 @@ class VariablePool:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
def get_instance(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True
|
||||
):
|
||||
"""Retrieve a variable instance from the variable pool.
|
||||
|
||||
Args:
|
||||
selector:
|
||||
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
Returns:
|
||||
The variable instance object if it exists; otherwise returns `default`.
|
||||
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
return variable_struct.instance
|
||||
|
||||
def get_value(
|
||||
self,
|
||||
selector: str,
|
||||
@@ -273,38 +303,52 @@ class VariablePool:
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
sys_namespace = self.variables.get("sys", {})
|
||||
if literal:
|
||||
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
conv_namespace = self.variables.get("conv", {})
|
||||
if literal:
|
||||
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
if literal:
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.to_literal()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
else:
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
return runtime_vars
|
||||
|
||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||
|
||||
@@ -132,24 +132,24 @@ class WorkflowExecutor:
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
@@ -158,24 +158,42 @@ class WorkflowExecutor:
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# Calculate elapsed time
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
@@ -231,23 +249,23 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
full_content = ''
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
@@ -272,7 +290,7 @@ class WorkflowExecutor:
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
full_content += msg_event["data"]["content"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
@@ -295,12 +313,12 @@ class WorkflowExecutor:
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
@@ -308,21 +326,39 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
# Append messages for user and assistant
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
|
||||
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
|
||||
)
|
||||
|
||||
yield {
|
||||
|
||||
@@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db
|
||||
from app.db import get_db_context
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
|
||||
if not agent_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
|
||||
|
||||
db = next(get_db())
|
||||
release = db.query(AppRelease).filter(
|
||||
AppRelease.id == agent_id
|
||||
).first()
|
||||
with get_db_context() as db:
|
||||
release = db.query(AppRelease).filter(
|
||||
AppRelease.id == agent_id
|
||||
).first()
|
||||
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
draft_service = DraftRunService(db)
|
||||
|
||||
|
||||
return draft_service, release, message
|
||||
return release, message
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""非流式执行
|
||||
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
|
||||
Returns:
|
||||
状态更新字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
|
||||
|
||||
# 执行 Agent(非流式)
|
||||
result = await draft_service.run(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
with get_db_context() as db:
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 执行 Agent(非流式)
|
||||
result = await draft_service.run(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
)
|
||||
|
||||
response = result.get("response", "")
|
||||
|
||||
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
|
||||
Yields:
|
||||
流式事件字典
|
||||
"""
|
||||
draft_service, release, message = self._prepare_agent(variable_pool)
|
||||
release, message = self._prepare_agent(variable_pool)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
with get_db_context() as db:
|
||||
draft_service = AgentRunService(db)
|
||||
# 执行 Agent(流式)
|
||||
async for chunk in draft_service.run_stream(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
full_response += content
|
||||
|
||||
# 流式返回每个 chunk
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": content,
|
||||
"full_content": full_response,
|
||||
"meta_data": chunk.get("meta_data", {})
|
||||
}
|
||||
async for chunk in draft_service.run_stream(
|
||||
agent_config=release.config,
|
||||
model_config=None,
|
||||
message=message,
|
||||
workspace_id=variable_pool.get_value("sys.workspace_id"),
|
||||
user_id=state.get("user_id"),
|
||||
variables=variable_pool.get_all_conversation_vars()
|
||||
):
|
||||
# 提取内容
|
||||
content = chunk.get("content", "")
|
||||
full_response += content
|
||||
|
||||
# 流式返回每个 chunk
|
||||
yield {
|
||||
"type": "chunk",
|
||||
"node_id": self.node_id,
|
||||
"content": content,
|
||||
"full_content": full_response,
|
||||
"meta_data": chunk.get("meta_data", {})
|
||||
}
|
||||
|
||||
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")
|
||||
|
||||
|
||||
@@ -88,6 +88,8 @@ class AssignerNode(BaseNode):
|
||||
await operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
await operator.remove_last()
|
||||
case AssignmentOperator.EXTEND:
|
||||
await operator.extend()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {assignment.operation}")
|
||||
logger.info(f"Node {self.node_id}: execution completed")
|
||||
|
||||
@@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel):
|
||||
- tags: 节点标签(用于分类和搜索)
|
||||
"""
|
||||
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||
)
|
||||
|
||||
description: str | None = Field(
|
||||
default=None,
|
||||
description="节点描述,说明节点的作用"
|
||||
)
|
||||
|
||||
tags: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="节点标签,用于分类和搜索"
|
||||
)
|
||||
# name: str | None = Field(
|
||||
# default=None,
|
||||
# description="节点名称(显示名称),如果不设置则使用节点 ID"
|
||||
# )
|
||||
#
|
||||
# description: str | None = Field(
|
||||
# default=None,
|
||||
# description="节点描述,说明节点的作用"
|
||||
# )
|
||||
#
|
||||
# tags: list[str] = Field(
|
||||
# default_factory=list,
|
||||
# description="节点标签,用于分类和搜索"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
"""Pydantic 配置"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
@@ -10,8 +11,11 @@ from app.core.config import settings
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -196,7 +200,7 @@ class BaseNode(ABC):
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Extract processed outputs using subclass-defined logic.
|
||||
extracted_output = self._extract_output(business_result)
|
||||
@@ -219,7 +223,7 @@ class BaseNode(ABC):
|
||||
} | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution timed out ({timeout} seconds)."
|
||||
)
|
||||
@@ -230,7 +234,7 @@ class BaseNode(ABC):
|
||||
variable_pool,
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution failed: {e}",
|
||||
exc_info=True,
|
||||
@@ -307,10 +311,10 @@ class BaseNode(ABC):
|
||||
"done": done
|
||||
})
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
@@ -337,7 +341,7 @@ class BaseNode(ABC):
|
||||
yield state_update | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
|
||||
error_output = self._wrap_error(
|
||||
f"Node execution timed out ({timeout}s)",
|
||||
@@ -347,7 +351,7 @@ class BaseNode(ABC):
|
||||
)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
yield error_output
|
||||
@@ -548,9 +552,9 @@ class BaseNode(ABC):
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
conv_vars=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars(),
|
||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
@@ -614,16 +618,45 @@ class BaseNode(ABC):
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
content: str | dict | FileObject,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
url=content.get("url"),
|
||||
transfer_method=content.get("transfer_method"),
|
||||
origin_file_type=content.get("origin_file_type"),
|
||||
file_id=content.get("file_id"),
|
||||
is_file=True
|
||||
)
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return {"text": content}
|
||||
return [{"type": "text", "text": content}]
|
||||
return content
|
||||
elif isinstance(content, dict):
|
||||
trans_tool = PROVIDER_STRATEGIES[provider]()
|
||||
result = await trans_tool.format_image(content["url"])
|
||||
return result
|
||||
raise TypeError('Unexpect input value type')
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
)
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
@@ -639,3 +672,12 @@ class BaseNode(ABC):
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def model_balance(model_config: ModelConfig) -> ModelApiKey:
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
raise ValueError("No active API keys available for model")
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
return api_keys[0]
|
||||
|
||||
@@ -91,8 +91,8 @@ class IterationRuntime:
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.get_all_conversation_vars().update(
|
||||
self.child_variable_pool.get_all_conversation_vars()
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
|
||||
@@ -156,7 +156,7 @@ class LoopRuntime:
|
||||
|
||||
def merge_conv_vars(self, loopstate):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
|
||||
if config.flatten:
|
||||
outputs['output'] = config.output_type
|
||||
else:
|
||||
outputs['output'] = VariableType.ARRAY_STRING
|
||||
outputs['output'] = VariableType.NESTED_ARRAY
|
||||
else:
|
||||
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
||||
return outputs
|
||||
|
||||
@@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig):
|
||||
description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="output",
|
||||
type=VariableType.STRING,
|
||||
description="工作流的最终输出"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
# # 输出变量定义
|
||||
# output_variables: list[VariableDefinition] = Field(
|
||||
# default_factory=lambda: [
|
||||
# VariableDefinition(
|
||||
# name="output",
|
||||
# type=VariableType.STRING,
|
||||
# description="工作流的最终输出"
|
||||
# )
|
||||
# ],
|
||||
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
|
||||
@@ -24,6 +24,9 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
NOTES = "notes"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
|
||||
@@ -61,6 +64,7 @@ class AssignmentOperator(StrEnum):
|
||||
APPEND = "append"
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
EXTEND = "extend"
|
||||
|
||||
|
||||
class HttpRequestMethod(StrEnum):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
|
||||
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return params
|
||||
|
||||
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""
|
||||
Build HTTP request body arguments for httpx request methods.
|
||||
|
||||
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
content["files"] = {}
|
||||
for item in self.typed_config.body.data:
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
||||
variable_pool)
|
||||
elif item.type == "file":
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
content["files"][self._render_template(item.key, variable_pool)] = (
|
||||
uuid.uuid4().hex,
|
||||
await variable_pool.get_instance(item.value).get_content()
|
||||
)
|
||||
content["data"] = data
|
||||
case HttpContentType.BINARY:
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
content["files"] = []
|
||||
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||
if isinstance(file_instence, ArrayVariable):
|
||||
for v in file_instence.value:
|
||||
if isinstance(v, FileVariable):
|
||||
content["files"].append(
|
||||
(
|
||||
"files", (uuid.uuid4().hex, await v.get_content())
|
||||
)
|
||||
)
|
||||
elif isinstance(file_instence, FileVariable):
|
||||
content["files"].append(
|
||||
(
|
||||
"file", (uuid.uuid4().hex, await file_instence.get_content())
|
||||
)
|
||||
)
|
||||
|
||||
case HttpContentType.WWW_FORM:
|
||||
content["data"] = json.loads(self._render_template(
|
||||
json.dumps(self.typed_config.body.data), variable_pool
|
||||
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
|
||||
request_func = self._get_client_method(client)
|
||||
resp = await request_func(
|
||||
url=self._render_template(self.typed_config.url, variable_pool),
|
||||
**self._build_content(variable_pool)
|
||||
**(await self._build_content(variable_pool))
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
@@ -236,5 +257,5 @@ class HttpRequestNode(BaseNode):
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
return {"output": "ERROR"}
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig):
|
||||
)
|
||||
|
||||
knowledge_bases: list[KnowledgeBaseConfig] = Field(
|
||||
...,
|
||||
default_factory=list,
|
||||
description="Knowledge base config"
|
||||
)
|
||||
|
||||
|
||||
@@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
RuntimeError: If no valid knowledge base is found or access is denied.
|
||||
"""
|
||||
self.typed_config = KnowledgeRetrievalNodeConfig(**self.config)
|
||||
if not self.typed_config.knowledge_bases:
|
||||
return []
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
@@ -112,11 +112,12 @@ class LLMNode(BaseNode):
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
@@ -129,7 +130,8 @@ class LLMNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
@@ -151,39 +153,53 @@ class LLMNode(BaseNode):
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": content
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_value(self.typed_config.vision_input)
|
||||
for file in files:
|
||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.append(content)
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
messages[-1]['content'] = [messages[-1]["content"]] + file_content
|
||||
messages[-1]['content'] = messages[-1]["content"] + file_content
|
||||
else:
|
||||
messages.append({"role": "user", "content": file_content})
|
||||
|
||||
if self.typed_config.memory.enable:
|
||||
messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:]
|
||||
history_message = []
|
||||
for message in state["messages"][-self.typed_config.memory.window_size:]:
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
|
||||
@@ -123,10 +123,10 @@ class NodeFactory:
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
raise ValueError(f"不支持的节点类型: {node_type}")
|
||||
raise ValueError(f"Unsupported node type: {node_type}")
|
||||
|
||||
# 创建节点实例
|
||||
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
|
||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||
return node_class(node_config, workflow_config)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode):
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
base_url = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
return RedBearLLM(
|
||||
@@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class StartNodeConfig(BaseNodeConfig):
|
||||
@@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig):
|
||||
description="自定义输入变量列表,这些变量会作为 Start 节点的输出"
|
||||
)
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
VariableDefinition(
|
||||
name="message",
|
||||
type=VariableType.STRING,
|
||||
description="用户输入的消息"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_vars",
|
||||
type=VariableType.OBJECT,
|
||||
description="会话级变量"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="execution_id",
|
||||
type=VariableType.STRING,
|
||||
description="执行 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="conversation_id",
|
||||
type=VariableType.STRING,
|
||||
description="会话 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="workspace_id",
|
||||
type=VariableType.STRING,
|
||||
description="工作空间 ID"
|
||||
),
|
||||
VariableDefinition(
|
||||
name="user_id",
|
||||
type=VariableType.STRING,
|
||||
description="用户 ID"
|
||||
)
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
# # 输出变量定义
|
||||
# output_variables: list[VariableDefinition] = Field(
|
||||
# default_factory=lambda: [
|
||||
# VariableDefinition(
|
||||
# name="message",
|
||||
# type=VariableType.STRING,
|
||||
# description="用户输入的消息"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="conversation_vars",
|
||||
# type=VariableType.OBJECT,
|
||||
# description="会话级变量"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="execution_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="执行 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="conversation_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="会话 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="workspace_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="工作空间 ID"
|
||||
# ),
|
||||
# VariableDefinition(
|
||||
# name="user_id",
|
||||
# type=VariableType.STRING,
|
||||
# description="用户 ID"
|
||||
# )
|
||||
# ],
|
||||
# description="输出变量定义(自动生成,通常不需要修改)"
|
||||
# )
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
|
||||
@@ -12,9 +12,20 @@ class ExpressionEvaluator:
|
||||
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@classmethod
|
||||
def normalize_template(cls, template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def evaluate(
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
@@ -37,6 +48,7 @@ class ExpressionEvaluator:
|
||||
"""
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
expression = cls.normalize_template(expression)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -39,6 +40,16 @@ class TemplateRenderer:
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normalize_template(template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
@@ -95,7 +106,7 @@ class TemplateRenderer:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
template = self.normalize_template(template)
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
@@ -138,7 +138,7 @@ class WorkflowValidator:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES]
|
||||
if len(node_ids) != len(set(node_ids)):
|
||||
duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1]
|
||||
errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}")
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var, dict):
|
||||
return cls.OBJECT
|
||||
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
class FileObject(BaseModel):
|
||||
type: FileType
|
||||
url: str
|
||||
__file: bool
|
||||
transfer_method: str
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
|
||||
import httpx
|
||||
from deprecated import deprecated
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
|
||||
from app.core.config import settings
|
||||
|
||||
T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
@@ -61,13 +63,16 @@ class FileVariable(BaseVariable):
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("__file"):
|
||||
if not value.get("is_file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"transfer_method": value.get("transfer_method"),
|
||||
"url": value.get('url'),
|
||||
"__file": True
|
||||
"file_id": value.get("file_id"),
|
||||
"origin_file_type": value.get("origin_file_type"),
|
||||
"is_file": True
|
||||
}
|
||||
)
|
||||
if isinstance(value, FileObject):
|
||||
@@ -80,8 +85,23 @@ class FileVariable(BaseVariable):
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
chunks = []
|
||||
|
||||
class ArrayObject(BaseVariable, Generic[T]):
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("GET", self.value.url) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes(8192):
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > settings.MAX_FILE_SIZE:
|
||||
raise ValueError(f"File too large: {total_bytes} bytes")
|
||||
chunks.append(chunk)
|
||||
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
class ArrayVariable(BaseVariable, Generic[T]):
|
||||
type = 'array'
|
||||
|
||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||
@@ -108,7 +128,7 @@ class ArrayObject(BaseVariable, Generic[T]):
|
||||
return [v.get_value() for v in self.value]
|
||||
|
||||
|
||||
class NestedArrayObject(BaseVariable):
|
||||
class NestedArrayVariable(BaseVariable):
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
@@ -116,23 +136,23 @@ class NestedArrayObject(BaseVariable):
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
if not isinstance(v, ArrayObject):
|
||||
if not isinstance(v, list):
|
||||
raise TypeError("All elements must be of type list")
|
||||
final_value.append(v)
|
||||
final_value.append(make_array(AnyVariable, v))
|
||||
return final_value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
|
||||
return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return [[item.get_value() for item in row] for row in self.value]
|
||||
return [[item for item in row.get_value()] for row in self.value]
|
||||
|
||||
|
||||
@deprecated(
|
||||
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyObject(BaseVariable):
|
||||
class AnyVariable(BaseVariable):
|
||||
type = 'any'
|
||||
|
||||
def valid_value(self, value: Any) -> Any:
|
||||
@@ -142,10 +162,10 @@ class AnyObject(BaseVariable):
|
||||
return str(self.value)
|
||||
|
||||
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
||||
"""简化 ArrayObject 创建,不需要重复写类型"""
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
|
||||
"""简化 ArrayVariable 创建,不需要重复写类型"""
|
||||
|
||||
return ArrayObject(child_type, value)
|
||||
return ArrayVariable(child_type, value)
|
||||
|
||||
|
||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
@@ -168,7 +188,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return make_array(DictVariable, value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return make_array(FileVariable, value)
|
||||
case VariableType.NESTED_ARRAY:
|
||||
return NestedArrayVariable(value)
|
||||
case VariableType.ANY:
|
||||
return AnyObject(value)
|
||||
return AnyVariable(value)
|
||||
case _:
|
||||
raise TypeError(f"Invalid type - {var_type}")
|
||||
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
@@ -78,6 +78,9 @@ class ModelConfig(BaseModel):
|
||||
description = Column(String, comment="模型描述")
|
||||
|
||||
# 模型配置参数
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
config = Column(JSON, comment="模型配置参数")
|
||||
# - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。
|
||||
# - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。
|
||||
@@ -118,6 +121,11 @@ class ModelApiKey(BaseModel):
|
||||
api_key = Column(String, nullable=False, comment="API密钥")
|
||||
api_base = Column(String, comment="API基础URL")
|
||||
|
||||
# 模型能力参数
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
|
||||
# 配置参数
|
||||
config = Column(JSON, comment="API Key特定配置")
|
||||
|
||||
@@ -155,6 +163,9 @@ class ModelBase(Base):
|
||||
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
||||
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
|
||||
# 关联关系
|
||||
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
||||
|
||||
@@ -9,7 +9,7 @@ Classes:
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Text, ForeignKey
|
||||
from sqlalchemy import Column, String, DateTime, Text, ForeignKey, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
@@ -25,6 +25,9 @@ class OntologyClass(Base):
|
||||
# 类型信息
|
||||
class_name = Column(String(200), nullable=False, comment="类型名称")
|
||||
class_description = Column(Text, nullable=True, comment="类型描述")
|
||||
|
||||
# 系统默认标识
|
||||
is_system_default = Column(Boolean, default=False, nullable=False, comment="是否为系统默认类型")
|
||||
|
||||
# 外键:关联到本体场景
|
||||
scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID")
|
||||
|
||||
@@ -9,7 +9,7 @@ Classes:
|
||||
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint, Boolean
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
@@ -28,6 +28,9 @@ class OntologyScene(Base):
|
||||
# 场景信息
|
||||
scene_name = Column(String(200), nullable=False, comment="场景名称")
|
||||
scene_description = Column(Text, nullable=True, comment="场景描述")
|
||||
|
||||
# 系统默认标识
|
||||
is_system_default = Column(Boolean, default=False, nullable=False, index=True, comment="是否为系统默认场景")
|
||||
|
||||
# 外键:关联到工作空间
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID")
|
||||
|
||||
@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和
|
||||
"""
|
||||
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
result = db.query(func.sum(Knowledge.chunk_num)).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id == "Memory"
|
||||
).scalar()
|
||||
|
||||
total = result if result is not None else 0
|
||||
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
|
||||
return total
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||
"""
|
||||
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
count = db.query(Knowledge).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id != "Memory"
|
||||
).count()
|
||||
|
||||
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -233,6 +233,7 @@ class MemoryConfigRepository:
|
||||
config_desc=params.config_desc,
|
||||
workspace_id=params.workspace_id,
|
||||
scene_id=params.scene_id,
|
||||
pruning_scene=params.pruning_scene,
|
||||
llm_id=params.llm_id,
|
||||
embedding_id=params.embedding_id,
|
||||
rerank_id=params.rerank_id,
|
||||
|
||||
@@ -428,19 +428,17 @@ class ModelConfigRepository:
|
||||
|
||||
try:
|
||||
# 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id
|
||||
model_config_ids = db.query(ModelConfig.id).join(
|
||||
ModelBase, ModelConfig.model_id == ModelBase.id
|
||||
).filter(
|
||||
model_config_ids = db.query(ModelConfig.id).filter(
|
||||
and_(
|
||||
or_(
|
||||
ModelConfig.tenant_id == tenant_id,
|
||||
ModelConfig.is_public
|
||||
),
|
||||
ModelBase.provider == provider,
|
||||
ModelConfig.provider == provider,
|
||||
ModelConfig.is_active,
|
||||
~ModelConfig.is_composite
|
||||
)
|
||||
).distinct().all()
|
||||
).all()
|
||||
|
||||
db_logger.debug(f"查询成功: 数量={len(model_config_ids)}")
|
||||
return [row[0] for row in model_config_ids]
|
||||
|
||||
@@ -374,7 +374,7 @@ class OntologySceneRepository:
|
||||
|
||||
count = self.db.query(OntologyScene).filter(
|
||||
OntologyScene.scene_id == scene_id,
|
||||
OntologyScene.workspace_id == workspace_id
|
||||
(OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
|
||||
).count()
|
||||
|
||||
is_owner = count > 0
|
||||
|
||||
@@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel):
|
||||
type: ApiKeyType = Field(..., description="API Key 类型")
|
||||
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
|
||||
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
|
||||
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)")
|
||||
rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制(请求/秒)")
|
||||
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
|
||||
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
|
||||
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
|
||||
@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
|
||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||
|
||||
@field_serializer('last_used_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer('created_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: datetime.datetime) -> int:
|
||||
def serialize_datetime(self, v: datetime.datetime) -> int:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -5,6 +5,8 @@ from enum import Enum, StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
from app.schemas.workflow_schema import WorkflowConfigCreate
|
||||
|
||||
|
||||
# ---------- Multimodal File Support ----------
|
||||
|
||||
@@ -19,8 +21,14 @@ class FileType(StrEnum):
|
||||
def trans(cls, value: str) -> 'FileType':
|
||||
if value.startswith("image"):
|
||||
return cls.IMAGE
|
||||
# TODO: other file type support
|
||||
raise RuntimeError("Unsupport file type")
|
||||
elif value.startswith("document"):
|
||||
return cls.DOCUMENT
|
||||
elif value.startswith("audio"):
|
||||
return cls.AUDIO
|
||||
elif value.startswith("video"):
|
||||
return cls.VIDEO
|
||||
else:
|
||||
raise RuntimeError("Unsupport file type")
|
||||
|
||||
|
||||
class TransferMethod(str, Enum):
|
||||
@@ -35,6 +43,12 @@ class FileInput(BaseModel):
|
||||
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
|
||||
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||
|
||||
def __init__(self, **data):
|
||||
if "type" in data:
|
||||
data['file_type'] = data['type']
|
||||
super().__init__(**data)
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
@@ -196,6 +210,8 @@ class AppCreate(BaseModel):
|
||||
# only for type=multi_agent
|
||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
workflow_config: Optional[WorkflowConfigCreate] = None
|
||||
|
||||
|
||||
class AppUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
@@ -429,7 +445,7 @@ class AppChatRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user