Compare commits
105 Commits
release/v0
...
v0.2.5-hot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
130f15665c | ||
|
|
026e4376d4 | ||
|
|
cf571cf02b | ||
|
|
218671ef06 | ||
|
|
34de0bb9c5 | ||
|
|
8e6cf09056 | ||
|
|
5929072b76 | ||
|
|
aa69cd3a0c | ||
|
|
a726a81224 | ||
|
|
9aae6163f0 | ||
|
|
941527e7ee | ||
|
|
a3f05220d3 | ||
|
|
7446241735 | ||
|
|
6033d37537 | ||
|
|
1524d7b5ce | ||
|
|
e00341a4cc | ||
|
|
f5185d2e95 | ||
|
|
dc9003f9db | ||
|
|
07e0c70629 | ||
|
|
37f77e0990 | ||
|
|
aef1a57ea8 | ||
|
|
69af479224 | ||
|
|
f38223c97f | ||
|
|
1ac6702eb0 | ||
|
|
2510f60dce | ||
|
|
b9d7fb2598 | ||
|
|
a39ba564fa | ||
|
|
34310bfabe | ||
|
|
78fd189510 | ||
|
|
94836ed9af | ||
|
|
229eb5cc86 | ||
|
|
bbb2c6c903 | ||
|
|
5edf3f2b8a | ||
|
|
006c6cd159 | ||
|
|
9675982555 | ||
|
|
3ac8a9431b | ||
|
|
5c42a84c3e | ||
|
|
b9578bd08a | ||
|
|
035e56e42f | ||
|
|
5a90d4776d | ||
|
|
f81fdca62a | ||
|
|
729c283c63 | ||
|
|
c99f04314c | ||
|
|
dd9be2ed90 | ||
|
|
2327be7557 | ||
|
|
a7ffc19ba1 | ||
|
|
bbaa39c569 | ||
|
|
d1de0250e7 | ||
|
|
2d731c6412 | ||
|
|
6a6e64f487 | ||
|
|
b9201c918a | ||
|
|
7dedad898a | ||
|
|
d497189352 | ||
|
|
fa4da8f467 | ||
|
|
e9ff742162 | ||
|
|
3849cfb835 | ||
|
|
c453af23c6 | ||
|
|
bcf2376f5a | ||
|
|
be2f56ae6a | ||
|
|
cbc9602495 | ||
|
|
c72ce381c0 | ||
|
|
2ef54168fc | ||
|
|
b33ccf00f9 | ||
|
|
829eb4b3be | ||
|
|
6c49456c13 | ||
|
|
fc8f06ee14 | ||
|
|
120a524b7e | ||
|
|
bd037ac3a3 | ||
|
|
b8ea427029 | ||
|
|
275be47224 | ||
|
|
4ea9c7e660 | ||
|
|
92d78d9a52 | ||
|
|
a820001eea | ||
|
|
8273f6d217 | ||
|
|
bd63e0fce8 | ||
|
|
12ba3d473e | ||
|
|
0b9cc0f068 | ||
|
|
5ca397befa | ||
|
|
da735fe776 | ||
|
|
b4f69f2cff | ||
|
|
1885c00cbc | ||
|
|
1e4fdeb1a6 | ||
|
|
cb7dbb0ed4 | ||
|
|
44083aec79 | ||
|
|
4a9b743153 | ||
|
|
b462e17a5b | ||
|
|
b272a52b57 | ||
|
|
79b19b744e | ||
|
|
3a09b26b6d | ||
|
|
dc2ea5c007 | ||
|
|
4fb673077a | ||
|
|
b3a136ac03 | ||
|
|
22f1bfa3fa | ||
|
|
b31e526e4d | ||
|
|
d477e24e34 | ||
|
|
6c7a68802b | ||
|
|
c5674246b0 | ||
|
|
f076199e3f | ||
|
|
e19d27f640 | ||
|
|
de545a69ca | ||
|
|
dc48ba540d | ||
|
|
81e92b4fa6 | ||
|
|
ebad5e00a3 | ||
|
|
4d98bace87 | ||
|
|
d0c0168c20 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -37,5 +37,4 @@ tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_python/target
|
||||
sandbox/lib/seccomp_nodejs/target
|
||||
sandbox/lib/seccomp_redbear/target
|
||||
|
||||
7
api/app/cache/__init__.py
vendored
7
api/app/cache/__init__.py
vendored
@@ -2,10 +2,7 @@
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
__all__ = []
|
||||
|
||||
8
api/app/cache/memory/__init__.py
vendored
8
api/app/cache/memory/__init__.py
vendored
@@ -2,11 +2,7 @@
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
__all__ = []
|
||||
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
134
api/app/cache/memory/emotion_memory.py
vendored
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
136
api/app/cache/memory/implicit_memory.py
vendored
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
@@ -4,6 +4,7 @@ from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -82,7 +83,8 @@ celery_app.conf.update(
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -92,9 +94,12 @@ celery_app.autodiscover_tasks(['app'])
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# 这个30秒的设计不合理
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
@@ -115,16 +120,16 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
#如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -19,6 +19,8 @@ from . import (
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
mcp_market_controller,
|
||||
mcp_market_config_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_episodic_controller,
|
||||
@@ -60,6 +62,8 @@ manager_router.include_router(model_controller.router)
|
||||
manager_router.include_router(file_controller.router)
|
||||
manager_router.include_router(document_controller.router)
|
||||
manager_router.include_router(knowledge_controller.router)
|
||||
manager_router.include_router(mcp_market_controller.router)
|
||||
manager_router.include_router(mcp_market_config_controller.router)
|
||||
manager_router.include_router(chunk_controller.router)
|
||||
manager_router.include_router(test_controller.router)
|
||||
manager_router.include_router(knowledgeshare_controller.router)
|
||||
|
||||
@@ -61,6 +61,7 @@ async def login_for_access_token(
|
||||
user = auth_service.register_user_with_invite(
|
||||
db=db,
|
||||
email=form_data.email,
|
||||
username=form_data.username,
|
||||
password=form_data.password,
|
||||
invite_token=form_data.invite,
|
||||
workspace_id=invite_info.workspace_id
|
||||
|
||||
@@ -208,14 +208,64 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期,自动触发生成
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
try:
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
except (ValueError, KeyError) as gen_e:
|
||||
# 预期内的业务异常:配置缺失、数据格式问题等
|
||||
api_logger.warning(
|
||||
f"自动生成建议失败(业务异常): {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
f"自动生成建议失败: {str(gen_e)}",
|
||||
""
|
||||
)
|
||||
except Exception as gen_e:
|
||||
# 非预期异常:记录完整 traceback 便于排查
|
||||
api_logger.error(
|
||||
f"自动生成建议时发生未预期异常: {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成建议时发生内部错误: {str(gen_e)}"
|
||||
)
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
@@ -159,12 +201,8 @@ async def get_preference_tags(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
@@ -330,12 +360,8 @@ async def get_behavior_habits(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
336
api/app/controllers/mcp_market_config_controller.py
Normal file
336
api/app/controllers/mcp_market_config_controller.py
Normal file
@@ -0,0 +1,336 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
import requests
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
from modelscope.hub.errors import raise_for_http_status
|
||||
from modelscope.hub.mcp_api import MCPApi
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_market_configs",
|
||||
tags=["mcp_market_configs"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_servers", response_model=ApiResponse)
|
||||
async def get_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': page,
|
||||
'page_size': pagesize,
|
||||
'search': keywords
|
||||
}
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": mcp_server_list,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
server_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get detailed information for a specific MCP Server
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market_config", response_model=ApiResponse)
|
||||
async def create_mcp_market_config(
|
||||
create_data: mcp_market_config_schema.McpMarketConfigCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market config
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_config_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market_config_by_mcp_market_id(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market config information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="Successfully obtained mcp market config information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market config exists
|
||||
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market_config, field):
|
||||
old_value = getattr(db_mcp_market_config, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market_config, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market_config(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market config
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market config exists
|
||||
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
|
||||
return success(msg="The mcp market config has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
262
api/app/controllers/mcp_market_controller.py
Normal file
262
api/app/controllers/mcp_market_controller.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import datetime
|
||||
import json
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import mcp_market_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/mcp_markets",
|
||||
tags=["mcp_markets"],
|
||||
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
|
||||
)
|
||||
|
||||
|
||||
@router.get("/mcp_markets", response_model=ApiResponse)
|
||||
async def get_mcp_markets(
|
||||
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
|
||||
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
|
||||
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
|
||||
desc: Optional[bool] = Query(False, description="Is it descending order"),
|
||||
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the mcp markets list in pages
|
||||
- Support keyword search for name,description
|
||||
- Support dynamic sorting
|
||||
- Return paging metadata + mcp_market list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
|
||||
|
||||
# 1. parameter validation
|
||||
if page < 1 or pagesize < 1:
|
||||
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
|
||||
# 2. Construct query conditions
|
||||
filters = []
|
||||
|
||||
# Keyword search (fuzzy matching of mcp market name,description)
|
||||
if keywords:
|
||||
api_logger.debug(f"Add keyword search criteria: {keywords}")
|
||||
filters.append(
|
||||
or_(
|
||||
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
|
||||
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
|
||||
)
|
||||
)
|
||||
# 3. Execute paged query
|
||||
try:
|
||||
api_logger.debug("Start executing mcp market paging query")
|
||||
total, items = mcp_market_service.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc,
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Query failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return structured response
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
|
||||
|
||||
|
||||
@router.post("/mcp_market", response_model=ApiResponse)
|
||||
async def create_mcp_market(
|
||||
create_data: mcp_market_schema.McpMarketCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
create mcp market
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {create_data.name}"
|
||||
)
|
||||
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market has been successfully created")
|
||||
except Exception as e:
|
||||
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def get_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Retrieve mcp market information based on mcp_market_id
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query mcp market information from the database
|
||||
api_logger.debug(f"Query mcp market: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
|
||||
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="Successfully obtained mcp market information")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def update_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
update_data: mcp_market_schema.McpMarketUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
# 1. Check if the mcp market exists
|
||||
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. not updating the name (name already exists)
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
if "name" in update_dict:
|
||||
name = update_dict["name"]
|
||||
if name != db_mcp_market.name:
|
||||
# Check if the mcp market name already exists
|
||||
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
|
||||
if db_mcp_market_exist:
|
||||
api_logger.warning(f"The mcp market name already exists: {name}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market name already exists: {name}"
|
||||
)
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
|
||||
updated_fields = []
|
||||
for field, value in update_dict.items():
|
||||
if hasattr(db_mcp_market, field):
|
||||
old_value = getattr(db_mcp_market, field)
|
||||
if old_value != value:
|
||||
# update value
|
||||
setattr(db_mcp_market, field, value)
|
||||
updated_fields.append(f"{field}: {old_value} -> {value}")
|
||||
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"The mcp market update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 5. Return the updated mcp market
|
||||
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
|
||||
msg="The mcp market information updated successfully")
|
||||
|
||||
|
||||
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
|
||||
async def delete_mcp_market(
|
||||
mcp_market_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
delete mcp market
|
||||
"""
|
||||
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Check whether the mcp market exists
|
||||
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(
|
||||
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or you do not have permission to access it"
|
||||
)
|
||||
|
||||
# 2. Deleting mcp market
|
||||
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
|
||||
return success(msg="The mcp market has been successfully deleted")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -633,12 +633,11 @@ async def get_knowledge_type_stats_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -469,6 +470,8 @@ async def get_chunk_insight(
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -503,6 +506,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -563,17 +575,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -602,10 +619,23 @@ async def dashboard_data(
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -328,7 +328,7 @@ async def update_composite_model(
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
@@ -368,6 +368,9 @@ def update_model(
|
||||
更新模型配置
|
||||
"""
|
||||
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
if model_data.type is not None or model_data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
|
||||
|
||||
@@ -2,15 +2,23 @@ from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.models.user_model import User
|
||||
from app.schemas import user_schema
|
||||
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
|
||||
from app.schemas.user_schema import (
|
||||
ChangePasswordRequest,
|
||||
AdminChangePasswordRequest,
|
||||
SendEmailCodeRequest,
|
||||
VerifyEmailCodeRequest,
|
||||
VerifyPasswordRequest)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -92,7 +100,7 @@ def get_current_user_info(
|
||||
result_schema.current_workspace_name = current_workspace.name
|
||||
|
||||
for ws in result.workspaces:
|
||||
if ws.workspace_id == current_user.current_workspace_id:
|
||||
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
|
||||
result_schema.role = ws.role
|
||||
break
|
||||
|
||||
@@ -120,6 +128,7 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
@@ -180,4 +189,54 @@ async def admin_change_password(
|
||||
return success(msg="密码修改成功")
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
|
||||
await user_service.verify_and_change_email(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
new_email=request.new_email,
|
||||
code=request.code
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
|
||||
4
api/app/core/__init__.py
Normal file
4
api/app/core/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
@@ -193,16 +193,29 @@ class Settings:
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
|
||||
# SMTP Email Configuration
|
||||
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
|
||||
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
|
||||
SMTP_USER: str = os.getenv("SMTP_USER", "")
|
||||
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
# Periodic Task Schedule Configuration
|
||||
# workspace_reflection: 每隔多少秒执行一次
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))
|
||||
# forgetting_cycle: 每隔多少小时执行一次
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))
|
||||
# implicit_emotions_update: 每天几点执行(小时,0-23)
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
|
||||
4
api/app/core/workflow/engine/__init__.py
Normal file
4
api/app/core/workflow/engine/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:28
|
||||
281
api/app/core/workflow/engine/event_stream_handler.py
Normal file
281
api/app/core/workflow/engine/event_stream_handler.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import datetime
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class EventStreamHandler:
|
||||
def __init__(
|
||||
self,
|
||||
output_coordinator: StreamOutputCoordinator,
|
||||
variable_pool: VariablePool,
|
||||
execution_id: str,
|
||||
):
|
||||
self.coordinator = output_coordinator
|
||||
self.variable_pool = variable_pool
|
||||
self.execution_id = execution_id
|
||||
|
||||
def update_stream_output_status(self, activate: dict, data: dict):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self.coordinator.update_scope_activation(node_id, status=node_output_status)
|
||||
|
||||
async def handle_updates_event(
|
||||
self,
|
||||
data: dict,
|
||||
graph: CompiledStateGraph,
|
||||
checkpoint_config: RunnableConfig
|
||||
):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
graph (CompiledStateGraph): The compiled LangGraph state machine.
|
||||
checkpoint_config (RunnableConfig): Configuration for the current execution context.)
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
state = graph.get_state(config=checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
self.update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.coordinator.activate_end and not wait:
|
||||
async for msg_event in self.coordinator.emit_activate_chunk(self.variable_pool):
|
||||
yield msg_event
|
||||
|
||||
if self.coordinator.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self.update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def handle_node_chunk_event(self, data: dict):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.coordinator.activate_end:
|
||||
end_info = self.coordinator.current_activate_end_info
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.coordinator.pop_current_activate_end()
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_node_error_event(data: dict):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def handle_debug_event(self, data: dict, input_data: dict):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_cycle_item_event(data: dict):
|
||||
yield {
|
||||
"event": "cycle_item",
|
||||
"data": data.get("data")
|
||||
}
|
||||
|
||||
|
||||
@@ -1,177 +1,28 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
from typing import Any, Iterable
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.stream_output_coordinator import OutputContent, StreamOutputConfig
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
@@ -230,7 +81,7 @@ class GraphBuilder:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
@staticmethod
|
||||
def _merge_control_nodes(control_nodes: list[tuple[str, str]]) -> dict[str, list]:
|
||||
def _merge_control_nodes(control_nodes: Iterable[tuple[str, str]]) -> dict[str, list]:
|
||||
result = defaultdict(list)
|
||||
for node in control_nodes:
|
||||
result[node[0]].append(node[1])
|
||||
104
api/app/core/workflow/engine/result_builder.py
Normal file
104
api/app/core/workflow/engine/result_builder.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
|
||||
class WorkflowResultBuilder:
|
||||
def build_final_output(
|
||||
self,
|
||||
result: dict,
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
variable_pool (VariablePool): Variable Pool
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self.aggregate_token_usage(node_outputs)
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
"sys": variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def aggregate_token_usage(node_outputs: dict) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
29
api/app/core/workflow/engine/runtime_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import uuid
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
checkpoint_config: RunnableConfig
|
||||
|
||||
@classmethod
|
||||
def create(cls, execution_id: str, workspace_id: str, user_id: str):
|
||||
return cls(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
checkpoint_config=RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
99
api/app/core/workflow/engine/state_manager.py
Normal file
99
api/app/core/workflow/engine/state_manager.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
from typing import Annotated, Any
|
||||
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(dict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
__required_keys__ = frozenset({
|
||||
"messages",
|
||||
"cycle_nodes",
|
||||
"looping",
|
||||
"node_outputs",
|
||||
"execution_id",
|
||||
"workspace_id",
|
||||
"user_id",
|
||||
"activate",
|
||||
})
|
||||
__optional_keys__ = frozenset({
|
||||
"error",
|
||||
"error_node",
|
||||
})
|
||||
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class WorkflowStateManager:
|
||||
def create_initial_state(
|
||||
self,
|
||||
workflow_config: dict,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext,
|
||||
start_node_id: str
|
||||
) -> WorkflowState:
|
||||
conversation_messages = input_data.get("conv_messages", [])
|
||||
|
||||
return WorkflowState(
|
||||
messages=conversation_messages,
|
||||
node_outputs={},
|
||||
execution_id=execution_context.execution_id,
|
||||
workspace_id=execution_context.workspace_id,
|
||||
user_id=execution_context.user_id,
|
||||
error=None,
|
||||
error_node=None,
|
||||
cycle_nodes=self._identify_cycle_nodes(workflow_config),
|
||||
looping=0,
|
||||
activate={
|
||||
start_node_id: True
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _identify_cycle_nodes(
|
||||
workflow_config: dict
|
||||
):
|
||||
return [
|
||||
node.get("id")
|
||||
for node in workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
]
|
||||
327
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
327
api/app/core/workflow/engine/stream_output_coordinator.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
SCOPE_PATTERN = re.compile(
|
||||
r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}"
|
||||
)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
|
||||
def initialize_end_outputs(
|
||||
self,
|
||||
end_node_map: dict[str, StreamOutputConfig]
|
||||
):
|
||||
self.end_outputs = end_node_map
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
return self.end_outputs.get(self.activate_end)
|
||||
|
||||
def pop_current_activate_end(self):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
def update_scope_activation(
|
||||
self,
|
||||
scope: str,
|
||||
status: str | None = None
|
||||
):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
force: bool = False
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
variable_pool (VariablePool): Pool of variables for evaluating segment values.
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except Exception as e:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def flush_remaining_chunk(
|
||||
self,
|
||||
variable_pool: VariablePool
|
||||
) -> AsyncGenerator[dict[str, str | dict], None]:
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
@@ -1,14 +1,7 @@
|
||||
"""
|
||||
变量池 (Variable Pool)
|
||||
|
||||
工作流执行的数据中心,管理所有变量的存储和访问。
|
||||
|
||||
变量类型:
|
||||
1. 系统变量 (sys.*) - 系统内置变量(execution_id, workspace_id, user_id, message 等)
|
||||
2. 节点输出 (node_id.*) - 节点执行结果
|
||||
3. 会话变量 (conv.*) - 会话级变量(跨多轮对话保持)
|
||||
"""
|
||||
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2025/12/15 19:50
|
||||
import logging
|
||||
import re
|
||||
from asyncio import Lock
|
||||
@@ -18,7 +11,8 @@ from typing import Any, Generic
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable.variable_objects import T, create_variable_instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -359,3 +353,77 @@ class VariablePool:
|
||||
f" runtime_vars={len(runtime_vars)}\n"
|
||||
f")"
|
||||
)
|
||||
|
||||
|
||||
class VariablePoolInitializer:
|
||||
def __init__(self, workflow_config: dict):
|
||||
self.workflow_config = workflow_config
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
execution_context: ExecutionContext
|
||||
) -> None:
|
||||
await self._init_conversation_vars(variable_pool, input_data)
|
||||
await self._init_system_vars(variable_pool, input_data, execution_context)
|
||||
|
||||
async def _init_conversation_vars(
|
||||
self,
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict
|
||||
):
|
||||
init_conv_vars: list[dict] = self.workflow_config.get("variables") or []
|
||||
runtime_conv_vars: dict[str, Any] = input_data.get("conv", {})
|
||||
|
||||
for var_def in init_conv_vars:
|
||||
var_name = var_def.get("name")
|
||||
var_default = runtime_conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _init_system_vars(
|
||||
variable_pool: VariablePool,
|
||||
input_data: dict,
|
||||
context: ExecutionContext
|
||||
):
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
conversations = input_data.get("conv_messages", [])
|
||||
conversation_index = len(conversations) // 2
|
||||
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_index": (conversation_index, VariableType.NUMBER),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (context.execution_id, VariableType.STRING),
|
||||
"workspace_id": (context.workspace_id, VariableType.STRING),
|
||||
"user_id": (context.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
"""
|
||||
工作流执行器
|
||||
|
||||
基于 LangGraph 的工作流执行引擎。
|
||||
"""
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 13:51
|
||||
import datetime
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.event_stream_handler import EventStreamHandler
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.result_builder import WorkflowResultBuilder
|
||||
from app.core.workflow.engine.runtime_schema import ExecutionContext
|
||||
from app.core.workflow.engine.state_manager import WorkflowStateManager
|
||||
from app.core.workflow.engine.stream_output_coordinator import StreamOutputCoordinator
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariablePoolInitializer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -30,9 +29,7 @@ class WorkflowExecutor:
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str,
|
||||
execution_context: ExecutionContext,
|
||||
):
|
||||
"""Initialize Workflow Executor.
|
||||
|
||||
@@ -41,13 +38,10 @@ class WorkflowExecutor:
|
||||
|
||||
Args:
|
||||
workflow_config (dict): The workflow configuration dictionary.
|
||||
execution_id (str): Unique identifier for this workflow execution.
|
||||
workspace_id (str): Workspace or project ID.
|
||||
user_id (str): User ID executing the workflow.
|
||||
execution_context (ExecutionContext): The workflow execution context
|
||||
include execution_id, workspace_id, user_id, checkpoint_config
|
||||
|
||||
Attributes:
|
||||
self.nodes (list): List of node definitions from workflow_config.
|
||||
self.edges (list): List of edge definitions from workflow_config.
|
||||
self.execution_config (dict): Optional execution parameters from workflow_config.
|
||||
self.start_node_id (str | None): ID of the Start node, set after graph build.
|
||||
self.end_outputs (dict[str, StreamOutputConfig]): End node output configs.
|
||||
@@ -57,555 +51,18 @@ class WorkflowExecutor:
|
||||
self.checkpoint_config (RunnableConfig): Config for LangGraph checkpointing.
|
||||
"""
|
||||
self.workflow_config = workflow_config
|
||||
self.execution_id = execution_id
|
||||
self.workspace_id = workspace_id
|
||||
self.user_id = user_id
|
||||
self.nodes = workflow_config.get("nodes", [])
|
||||
self.edges = workflow_config.get("edges", [])
|
||||
self.execution_context = execution_context
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.start_node_id: str | None = None
|
||||
self.variable_pool: VariablePool | None = None
|
||||
|
||||
self.graph: CompiledStateGraph | None = None
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4(),
|
||||
}
|
||||
)
|
||||
|
||||
async def __init_variable_pool(self, input_data: dict[str, Any]):
|
||||
"""Initialize the variable pool with system, conversation, and input variables.
|
||||
|
||||
This method populates the VariablePool instance with:
|
||||
- Conversation-level variables (`conv` namespace) from workflow config or provided values.
|
||||
- System variables (`sys` namespace) such as message, files, conversation_id, execution_id, workspace_id, user_id, and input_variables.
|
||||
|
||||
Args:
|
||||
input_data (dict): Input data for workflow execution, may contain:
|
||||
- "message": user message (str)
|
||||
- "file": list of user-uploaded files
|
||||
- "conv": existing conversation variables (dict)
|
||||
- "variables": custom variables for the Start node (dict)
|
||||
- "conversation_id": conversation identifier
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
user_files = input_data.get("files") or []
|
||||
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conv_vars = input_data.get("conv", {})
|
||||
|
||||
# Initialize conversation variables (conv namespace)
|
||||
for var_def in config_variables_list:
|
||||
var_name = var_def.get("name")
|
||||
var_default = conv_vars.get(var_name, var_def.get("default"))
|
||||
var_type = var_def.get("type")
|
||||
if var_name:
|
||||
if var_default:
|
||||
var_value = var_default
|
||||
else:
|
||||
var_value = DEFAULT_VALUE(var_type)
|
||||
await self.variable_pool.new(
|
||||
namespace="conv",
|
||||
key=var_name,
|
||||
value=var_value,
|
||||
var_type=var_type,
|
||||
mut=True
|
||||
)
|
||||
|
||||
# Initialize system variables (sys namespace)
|
||||
input_variables = input_data.get("variables") or {}
|
||||
sys_vars = {
|
||||
"message": (user_message, VariableType.STRING),
|
||||
"conversation_id": (input_data.get("conversation_id"), VariableType.STRING),
|
||||
"execution_id": (self.execution_id, VariableType.STRING),
|
||||
"workspace_id": (self.workspace_id, VariableType.STRING),
|
||||
"user_id": (self.user_id, VariableType.STRING),
|
||||
"input_variables": (input_variables, VariableType.OBJECT),
|
||||
"files": (user_files, VariableType.ARRAY_FILE)
|
||||
}
|
||||
for key, var_def in sys_vars.items():
|
||||
value = var_def[0]
|
||||
var_type = var_def[1]
|
||||
await self.variable_pool.new(
|
||||
namespace='sys',
|
||||
key=key,
|
||||
value=value,
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
def _prepare_initial_state(self, input_data: dict[str, Any]) -> WorkflowState:
|
||||
"""Generate the initial workflow state for execution.
|
||||
|
||||
This method prepares the runtime state dictionary with system variables,
|
||||
conversation variables, node outputs, loop tracking, and activation flags.
|
||||
|
||||
Args:
|
||||
input_data (dict): The input payload for workflow execution.
|
||||
Expected keys:
|
||||
- "conv_messages" (list, optional): Historical conversation messages
|
||||
to include in the workflow state.
|
||||
|
||||
Returns:
|
||||
WorkflowState: A dictionary representing the initialized workflow state
|
||||
with the following keys:
|
||||
- "messages": List of conversation messages
|
||||
- "node_outputs": Empty dict to store outputs of executed nodes
|
||||
- "execution_id": Current workflow execution ID
|
||||
- "workspace_id": Current workspace ID
|
||||
- "user_id": ID of the user triggering execution
|
||||
- "error": None initially, will store error message if a node fails
|
||||
- "error_node": None initially, will store ID of node that caused error
|
||||
- "cycle_nodes": List of node IDs that are of type LOOP or ITERATION
|
||||
- "looping": Integer flag indicating loop execution state (0 = not looping)
|
||||
- "activate": Dict mapping node IDs to activation status; initially
|
||||
only the start node is active
|
||||
"""
|
||||
conversation_messages = input_data.get("conv_messages") or []
|
||||
|
||||
return {
|
||||
"messages": conversation_messages,
|
||||
"node_outputs": {},
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": 0, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {
|
||||
self.start_node_id: True
|
||||
}
|
||||
}
|
||||
|
||||
def _build_final_output(self, result, elapsed_time, final_output):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
This method aggregates node outputs, token usage, conversation and system
|
||||
variables, messages, and other metadata into a consistent dictionary
|
||||
structure suitable for returning from workflow execution.
|
||||
|
||||
Args:
|
||||
result (dict): The runtime state returned by the workflow graph execution.
|
||||
Expected keys include:
|
||||
- "node_outputs" (dict): Outputs of executed nodes.
|
||||
- "messages" (list): Conversation messages exchanged during execution.
|
||||
- "error" (str, optional): Error message if any node failed.
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
- "status": Execution status ("completed")
|
||||
- "output": Aggregated final output content
|
||||
- "variables": Namespace dictionary with:
|
||||
- "conv": Conversation variables
|
||||
- "sys": System variables
|
||||
- "node_outputs": Outputs from all executed nodes
|
||||
- "messages": Conversation messages exchanged
|
||||
- "conversation_id": ID of the current conversation
|
||||
- "elapsed_time": Total execution time in seconds
|
||||
- "token_usage": Aggregated token usage across nodes (if available)
|
||||
- "error": Error message if any occurred during execution
|
||||
"""
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = self.variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": self.variable_pool.get_all_conversation_vars(),
|
||||
"sys": self.variable_pool.get_all_system_vars()
|
||||
},
|
||||
"node_outputs": node_outputs,
|
||||
"messages": result.get("messages", []),
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": token_usage,
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_scope_activate(self, scope, status=None):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if any control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
def _update_stream_output_status(self, activate, data):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = self.variable_pool.get_value(f"{node_id}.output", default=None, strict=False)
|
||||
self._update_scope_activate(node_id, status=node_output_status)
|
||||
|
||||
async def _emit_active_chunks(
|
||||
self,
|
||||
force=False
|
||||
):
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = self.variable_pool.get_literal(current_segment.literal)
|
||||
final_chunk += chunk
|
||||
except KeyError:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
# Remove End node from active tracking if all segments have been processed
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
async def _handle_updates_event(self, data):
|
||||
"""
|
||||
Handle workflow state update events ("updates") and stream active End node outputs.
|
||||
|
||||
Steps:
|
||||
1. Retrieve the current graph state.
|
||||
2. Extract node activation information from the state.
|
||||
3. Update the activation status of all End nodes.
|
||||
4. While there is an active End node:
|
||||
- Call _emit_active_chunks() to yield all currently active output segments.
|
||||
- After all segments are processed, update activate_end if there are remaining End nodes.
|
||||
5. Log a debug message indicating state update received.
|
||||
|
||||
Args:
|
||||
data (dict): The latest node state updates.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output event, each chunk in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Get the latest workflow state
|
||||
state = self.graph.get_state(config=self.checkpoint_config).values
|
||||
activate = state.get("activate", {})
|
||||
|
||||
# Update End node activation based on the new state
|
||||
self._update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.activate_end and not wait:
|
||||
async for msg_event in self._emit_active_chunks():
|
||||
yield msg_event
|
||||
|
||||
if self.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self._update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] Received state update from nodes: {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
async def _handle_node_chunk_event(self, data):
|
||||
"""
|
||||
Handle streaming chunk events from individual nodes ("node_chunk").
|
||||
|
||||
This method processes output segments for the currently active End node.
|
||||
If the segment depends on the provided node_id:
|
||||
- If the node has finished execution (`done=True`), advance the cursor.
|
||||
- If all segments are processed, deactivate the End node.
|
||||
- Otherwise, yield the current chunk as a streaming message.
|
||||
|
||||
Args:
|
||||
data (dict): Node chunk event data, expected keys:
|
||||
- "node_id": ID of the node producing this chunk
|
||||
- "chunk": Chunk of output text
|
||||
- "done": Boolean indicating whether the node finished producing output
|
||||
|
||||
Yields:
|
||||
dict: Streaming message event in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
if self.activate_end:
|
||||
end_info = self.end_outputs.get(self.activate_end)
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
return
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
else:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_node_error_event(self, data):
|
||||
"""
|
||||
Handle node error events ("node_error") during workflow execution.
|
||||
|
||||
This method streams an error event for a node that has failed. The event
|
||||
contains the node ID, status, input data, elapsed time, and error message.
|
||||
|
||||
Args:
|
||||
data (dict): Node error event data, expected keys:
|
||||
- "node_id": ID of the node that failed
|
||||
- "input_data": The input data that caused the error
|
||||
- "elapsed_time": Execution time before the error occurred
|
||||
- "error": Error message or exception string
|
||||
|
||||
Yields:
|
||||
dict: Node error event in the format:
|
||||
{
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"status": "failed",
|
||||
"input": ...,
|
||||
"elapsed_time": float,
|
||||
"output": None,
|
||||
"error": str
|
||||
}
|
||||
}
|
||||
"""
|
||||
node_id = data.get("node_id")
|
||||
yield {
|
||||
"event": "node_error",
|
||||
"data": {
|
||||
"node_id": node_id,
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
|
||||
async def _handle_debug_event(self, data, input_data):
|
||||
"""
|
||||
Handle debug events ("debug") related to node execution status.
|
||||
|
||||
This method streams debug events for nodes, including when a node starts
|
||||
execution ("node_start") and when it completes execution ("node_end").
|
||||
It filters out nodes with names starting with "nop" as no-operation nodes.
|
||||
|
||||
Args:
|
||||
data (dict): Debug event data, expected keys:
|
||||
- "type": Event type ("task" for start, "task_result" for completion)
|
||||
- "payload": Node-related information, including:
|
||||
- "name": Node name / ID
|
||||
- "input": Node input data (for "task" type)
|
||||
- "result": Node execution result (for "task_result" type)
|
||||
- "timestamp": ISO timestamp string of the event
|
||||
input_data (dict): Original workflow input data (used to get conversation_id)
|
||||
|
||||
Yields:
|
||||
dict: Node debug event in one of the following formats:
|
||||
1. Node start:
|
||||
{
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms)
|
||||
}
|
||||
}
|
||||
2. Node end:
|
||||
{
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": str,
|
||||
"conversation_id": str,
|
||||
"execution_id": str,
|
||||
"timestamp": int (ms),
|
||||
"input": dict,
|
||||
"output": Any,
|
||||
"elapsed_time": float
|
||||
}
|
||||
}
|
||||
"""
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
# Skip no-operation nodes
|
||||
if node_name and node_name.startswith("nop"):
|
||||
return
|
||||
|
||||
if event_type == "task":
|
||||
# Node starts execution
|
||||
inputv = payload.get("input", {})
|
||||
if not inputv.get("activate", {}).get(node_name):
|
||||
return
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-START] Node '{node_name}' execution started - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
# Node execution completed
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
return
|
||||
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-END] Node '{node_name}' execution completed - execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_end",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
}
|
||||
|
||||
async def _flush_remaining_chunk(self):
|
||||
"""
|
||||
Flush and yield all remaining output segments from active End nodes.
|
||||
|
||||
This method ensures that any remaining chunks of output, which may not have
|
||||
been emitted during normal streaming due to activation conditions, are fully
|
||||
processed. It is typically called at the end of a workflow to guarantee
|
||||
that all output is delivered.
|
||||
|
||||
Behavior:
|
||||
1. Filter `end_outputs` to only keep End nodes that are still active.
|
||||
2. While there is an active End node (`self.activate_end`):
|
||||
- Call `_emit_active_chunks(force=True)` to emit all segments regardless
|
||||
of their activation state.
|
||||
- If the current End node finishes, move to the next active End node
|
||||
if any remain.
|
||||
|
||||
Yields:
|
||||
dict: Streamed output events in the format:
|
||||
{"event": "message", "data": {"chunk": ...}}
|
||||
"""
|
||||
# Keep only active End nodes
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
# Force emit all remaining chunks of the active End node
|
||||
async for msg_event in self._emit_active_chunks(force=True):
|
||||
yield msg_event
|
||||
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
self.variable_initializer = VariablePoolInitializer(workflow_config)
|
||||
self.state_manager = WorkflowStateManager()
|
||||
self.result_builder = WorkflowResultBuilder()
|
||||
self.stream_coordinator = StreamOutputCoordinator()
|
||||
self.event_handler: EventStreamHandler | None = None
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""
|
||||
@@ -624,16 +81,22 @@ class WorkflowExecutor:
|
||||
Returns:
|
||||
CompiledStateGraph: The compiled and ready-to-run state graph.
|
||||
"""
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
|
||||
builder = GraphBuilder(
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.end_outputs = builder.end_node_map
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_id}")
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
output_coordinator=self.stream_coordinator,
|
||||
variable_pool=self.variable_pool,
|
||||
execution_id=self.execution_context.execution_id
|
||||
)
|
||||
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
return self.graph
|
||||
|
||||
@@ -665,7 +128,7 @@ class WorkflowExecutor:
|
||||
- token_usage: aggregated token usage if available
|
||||
- error: error message if any
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
@@ -673,16 +136,25 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
full_content = ''
|
||||
for end_id in self.end_outputs.keys():
|
||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
@@ -703,15 +175,16 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return self._build_final_output(result, elapsed_time, full_content)
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
@@ -744,15 +217,15 @@ class WorkflowExecutor:
|
||||
"data": {...}
|
||||
}
|
||||
"""
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_id}")
|
||||
logger.info(f"Starting workflow execution (streaming): execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
yield {
|
||||
"event": "workflow_start",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"workspace_id": self.execution_context.workspace_id,
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
@@ -762,18 +235,27 @@ class WorkflowExecutor:
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.__init_variable_pool(input_data)
|
||||
initial_state = self._prepare_initial_state(input_data)
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
full_content = ''
|
||||
self._update_scope_activate("sys")
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
# Execute the workflow with streaming
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
config=self.checkpoint_config
|
||||
config=self.execution_context.checkpoint_config
|
||||
):
|
||||
# event should be a tuple: (mode, data)
|
||||
# But let's handle both cases
|
||||
@@ -782,38 +264,46 @@ class WorkflowExecutor:
|
||||
else:
|
||||
# Unexpected format, log and skip
|
||||
logger.warning(f"[STREAM] Unexpected event format: {type(event)}, value: {event}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
continue
|
||||
|
||||
if mode == "custom":
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self._handle_node_chunk_event(data):
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
async for error_event in self._handle_node_error_event(data):
|
||||
async for error_event in self.event_handler.handle_node_error_event(data):
|
||||
yield error_event
|
||||
|
||||
elif event_type == "cycle_item":
|
||||
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
|
||||
yield cycle_event
|
||||
|
||||
elif mode == "debug":
|
||||
async for debug_event in self._handle_debug_event(data, input_data):
|
||||
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
||||
yield debug_event
|
||||
|
||||
elif mode == "updates":
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
async for msg_event in self._handle_updates_event(data):
|
||||
f"- execution_id: {self.execution_context.execution_id}")
|
||||
async for msg_event in self.event_handler.handle_updates_event(
|
||||
data,
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self._flush_remaining_chunk():
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
@@ -832,24 +322,25 @@ class WorkflowExecutor:
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
|
||||
)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_id}, error={e}", exc_info=True)
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -857,46 +348,6 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""
|
||||
Aggregate token usage statistics across all nodes.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): A dictionary of all node outputs.
|
||||
|
||||
Returns:
|
||||
dict | None: Aggregated token usage in the format:
|
||||
{
|
||||
"prompt_tokens": int,
|
||||
"completion_tokens": int,
|
||||
"total_tokens": int
|
||||
}
|
||||
Returns None if no token usage information is available.
|
||||
"""
|
||||
total_prompt_tokens = 0
|
||||
total_completion_tokens = 0
|
||||
total_tokens = 0
|
||||
has_token_info = False
|
||||
|
||||
for node_output in node_outputs.values():
|
||||
if isinstance(node_output, dict):
|
||||
token_usage = node_output.get("token_usage")
|
||||
if token_usage and isinstance(token_usage, dict):
|
||||
has_token_info = True
|
||||
total_prompt_tokens += token_usage.get("prompt_tokens", 0)
|
||||
total_completion_tokens += token_usage.get("completion_tokens", 0)
|
||||
total_tokens += token_usage.get("total_tokens", 0)
|
||||
|
||||
if not has_token_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt_tokens": total_prompt_tokens,
|
||||
"completion_tokens": total_completion_tokens,
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
@@ -918,12 +369,15 @@ async def execute_workflow(
|
||||
Returns:
|
||||
dict: Workflow execution result.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
return await executor.execute(input_data)
|
||||
|
||||
|
||||
@@ -947,11 +401,14 @@ async def execute_workflow_stream(
|
||||
Yields:
|
||||
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
|
||||
"""
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context = ExecutionContext.create(
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
)
|
||||
executor = WorkflowExecutor(
|
||||
workflow_config=workflow_config,
|
||||
execution_context=execution_context
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
@@ -14,16 +15,14 @@ from app.core.workflow.nodes.jinja_render import JinjaRenderNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.tool import ToolNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from app.core.workflow.nodes.code import CodeNode
|
||||
|
||||
__all__ = [
|
||||
"BaseNode",
|
||||
"WorkflowState",
|
||||
"LLMNode",
|
||||
"AgentNode",
|
||||
"IfElseNode",
|
||||
|
||||
@@ -7,14 +7,16 @@ Agent 节点实现
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.models import AppRelease
|
||||
from app.db import get_db
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance, AssignmentOperatorResolver
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -5,57 +5,17 @@ from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langgraph.config import get_stream_writer
|
||||
from typing_extensions import TypedDict, Annotated
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""Workflow state
|
||||
|
||||
The state object passed between nodes in a workflow, containing messages, variables, node outputs, etc.
|
||||
"""
|
||||
# List of messages (append mode)
|
||||
messages: Annotated[list[dict[str, str]], lambda x, y: y]
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Node outputs (stores execution results of each node for variable references)
|
||||
# Uses a custom merge function to combine new node outputs into the existing dictionary
|
||||
node_outputs: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# Execution context
|
||||
execution_id: str
|
||||
workspace_id: str
|
||||
user_id: str
|
||||
|
||||
# Error information (for error edges)
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
"""Base class for workflow nodes.
|
||||
|
||||
@@ -584,7 +544,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The rendered string with all variables substituted.
|
||||
"""
|
||||
from app.core.workflow.template_renderer import render_template
|
||||
from app.core.workflow.utils.template_renderer import render_template
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
@@ -611,7 +571,7 @@ class BaseNode(ABC):
|
||||
Returns:
|
||||
The boolean result of evaluating the expression.
|
||||
"""
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,13 +6,14 @@ import urllib.parse
|
||||
from string import Template
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import BaseNode
|
||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -25,6 +29,7 @@ class IterationRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -42,6 +47,7 @@ class IterationRuntime:
|
||||
state: Current workflow state at the point of iteration.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -49,6 +55,12 @@ class IterationRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
@@ -91,7 +103,46 @@ class IterationRuntime:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
await self._init_iteration_state(item, idx),
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
result = self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
@@ -152,16 +203,9 @@ class IterationRuntime:
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
result = await self.run_task(item, idx)
|
||||
self.merge_conv_vars()
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
child_state.append(result)
|
||||
idx += 1
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -27,6 +30,7 @@ class LoopRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -46,6 +50,7 @@ class LoopRuntime:
|
||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -53,6 +58,13 @@ class LoopRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
async def _init_loop_state(self):
|
||||
"""
|
||||
@@ -142,10 +154,12 @@ class LoopRuntime:
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def merge_conv_vars(self):
|
||||
def merge_conv_vars(self, loopstate):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
"""
|
||||
@@ -175,6 +189,50 @@ class LoopRuntime:
|
||||
else:
|
||||
return any(conditions)
|
||||
|
||||
async def _run(self, loopstate, idx):
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
loopstate,
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
cycle_variable = None
|
||||
if node_type == NodeType.CYCLE_START:
|
||||
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
return self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
return await self.graph.ainvoke(loopstate)
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Execute the loop node until termination conditions are met.
|
||||
@@ -190,15 +248,17 @@ class LoopRuntime:
|
||||
loopstate = await self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
idx = 0
|
||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
result = await self._run(loopstate, idx)
|
||||
child_state.append(result)
|
||||
|
||||
self.merge_conv_vars()
|
||||
self.merge_conv_vars(loopstate)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
idx += 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Any
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.iteration import IterationRuntime
|
||||
from app.core.workflow.nodes.cycle_graph.loop import LoopRuntime
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -136,7 +136,7 @@ class CycleGraphNode(BaseNode):
|
||||
2. Construct a StateGraph using GraphBuilder in subgraph mode
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool,
|
||||
).run()
|
||||
}
|
||||
return
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
}
|
||||
return
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -6,9 +6,10 @@ End 节点实现
|
||||
|
||||
import logging
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -7,11 +7,12 @@ import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class IfElseNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("cases")
|
||||
@classmethod
|
||||
def validate_case_number(cls, v, info):
|
||||
def validate_case_number(cls, v):
|
||||
if len(v) < 1:
|
||||
raise ValueError("At least one cases are required")
|
||||
return v
|
||||
|
||||
@@ -2,12 +2,13 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.jinja_render.config import JinjaRenderNodeConfig
|
||||
from app.core.workflow.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.utils.template_renderer import TemplateRenderer
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -6,10 +6,11 @@ from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@@ -56,7 +57,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
model_id: str = Field(
|
||||
model_id: uuid.UUID = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
@@ -148,7 +149,7 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
def validate_input_mode(cls, v):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
@@ -13,10 +13,11 @@ from langchain_core.messages import AIMessage
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
@@ -268,7 +269,7 @@ class LLMNode(BaseNode):
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.tasks import write_message_task
|
||||
|
||||
@@ -3,9 +3,9 @@ import re
|
||||
from abc import ABC
|
||||
from typing import Union, Type, NoReturn, Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import ValueInputType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
class TypeTransformer:
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
import json_repair
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from jinja2 import Template
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -7,10 +7,11 @@ Start 节点实现
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,16 +4,17 @@ import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.services.tool_service import ToolService
|
||||
from app.db import get_db_read
|
||||
from app.services.tool_service import ToolService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
|
||||
@@ -2,11 +2,11 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
4
api/app/core/workflow/utils/__init__.py
Normal file
4
api/app/core/workflow/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 16:24
|
||||
@@ -5,7 +5,6 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -187,7 +187,7 @@ class WorkflowValidator:
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.utils.expression_evaluator import ExpressionEvaluator
|
||||
var_errors = ExpressionEvaluator.validate_variable_names(variables)
|
||||
errors.extend(var_errors)
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ from .generic_file_model import GenericFile
|
||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey, ModelBase, LoadBalanceStrategy
|
||||
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||
from .knowledgeshare_model import KnowledgeShare
|
||||
from .mcp_market_model import McpMarket
|
||||
from .mcp_market_config_model import McpMarketConfig
|
||||
from .app_model import App
|
||||
from .agent_app_config_model import AgentConfig
|
||||
from .app_release_model import AppRelease
|
||||
@@ -33,6 +35,7 @@ from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -50,6 +53,8 @@ __all__ = [
|
||||
"ModelType",
|
||||
"ModelApiKey",
|
||||
"KnowledgeShare",
|
||||
"McpMarket",
|
||||
"McpMarketConfig",
|
||||
"App",
|
||||
"AgentConfig",
|
||||
"AppRelease",
|
||||
@@ -86,5 +91,6 @@ __all__ = [
|
||||
"MemoryPerceptualModel",
|
||||
"ModelBase",
|
||||
"LoadBalanceStrategy",
|
||||
"Skill"
|
||||
"Skill",
|
||||
"ImplicitEmotionsStorage"
|
||||
]
|
||||
|
||||
@@ -35,7 +35,7 @@ class FileMetadata(Base):
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Tenant ID")
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=False, index=True, comment="Workspace ID")
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, index=True, comment="Workspace ID")
|
||||
file_key = Column(String(512), nullable=False, unique=True, index=True, comment="Storage file key")
|
||||
file_name = Column(String(255), nullable=False, comment="Original file name")
|
||||
file_ext = Column(String(32), nullable=False, comment="File extension")
|
||||
|
||||
45
api/app/models/implicit_emotions_storage_model.py
Normal file
45
api/app/models/implicit_emotions_storage_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Implicit Emotions Storage Model
|
||||
|
||||
数据库模型:存储用户的隐性记忆画像和情绪建议数据
|
||||
替代原有的Redis缓存方式
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ImplicitEmotionsStorage(Base):
|
||||
"""隐性记忆和情绪存储表"""
|
||||
|
||||
__tablename__ = "implicit_emotions_storage"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID")
|
||||
|
||||
# 用户标识(unique=True会自动创建唯一索引)
|
||||
end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID")
|
||||
|
||||
# 隐性记忆画像数据(JSON格式)
|
||||
implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据")
|
||||
|
||||
# 情绪建议数据(JSON格式)
|
||||
emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间")
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
|
||||
|
||||
# 数据生成时间(用于业务逻辑)
|
||||
implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间")
|
||||
emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间")
|
||||
|
||||
# 索引(只为updated_at创建索引,end_user_id的unique约束已自动创建索引)
|
||||
__table_args__ = (
|
||||
Index('idx_updated_at', 'updated_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ImplicitEmotionsStorage(id={self.id}, end_user_id={self.end_user_id})>"
|
||||
16
api/app/models/mcp_market_config_model.py
Normal file
16
api/app/models/mcp_market_config_model.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
class McpMarketConfig(Base):
|
||||
__tablename__ = "mcp_market_configs"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
mcp_market_id = Column(UUID(as_uuid=True), nullable=False, comment="mcp_markets.id")
|
||||
token = Column(String, nullable=True, comment="mcp market token")
|
||||
status = Column(Integer, default=0, comment="connect status(0: Not connected, 1: connected)")
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, comment="tenant.id")
|
||||
created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
18
api/app/models/mcp_market_model.py
Normal file
18
api/app/models/mcp_market_model.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
class McpMarket(Base):
|
||||
__tablename__ = "mcp_markets"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
name = Column(String, index=True, nullable=False, comment="mcp market name")
|
||||
description = Column(String, index=True, nullable=True, comment="mcp market description")
|
||||
logo_url = Column(String, index=True, nullable=True, comment="logo url")
|
||||
mcp_count = Column(Integer, default=1, comment="mcp count")
|
||||
url = Column(String, index=True, nullable=False, comment="mcp market url")
|
||||
category = Column(String, index=True, nullable=False, comment="category")
|
||||
created_by = Column(UUID(as_uuid=True), nullable=False, comment="users.id")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
169
api/app/repositories/implicit_emotions_storage_repository.py
Normal file
169
api/app/repositories/implicit_emotions_storage_repository.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Implicit Emotions Storage Repository
|
||||
|
||||
数据访问层:处理隐性记忆和情绪数据的数据库操作
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from typing import Optional, Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, not_, exists
|
||||
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitEmotionsStorageRepository:
|
||||
"""隐性记忆和情绪存储仓储类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]:
|
||||
"""根据终端用户ID获取存储记录"""
|
||||
try:
|
||||
stmt = select(ImplicitEmotionsStorage).where(
|
||||
ImplicitEmotionsStorage.end_user_id == end_user_id
|
||||
)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}")
|
||||
return None
|
||||
|
||||
def create(self, end_user_id: str) -> ImplicitEmotionsStorage:
|
||||
"""创建新的存储记录(事务由调用方提交)"""
|
||||
storage = ImplicitEmotionsStorage(
|
||||
end_user_id=end_user_id,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
self.db.add(storage)
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_implicit_profile(
|
||||
self,
|
||||
end_user_id: str,
|
||||
profile_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新隐性记忆画像数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.implicit_profile = profile_data
|
||||
storage.implicit_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
suggestions_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新情绪建议数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.emotion_suggestions = suggestions_data
|
||||
storage.emotion_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新情绪建议成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取所有已存储数据的用户ID(避免大数据量内存溢出)
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(ImplicitEmotionsStorage.end_user_id)
|
||||
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from batch
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||
|
||||
查询逻辑:end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。
|
||||
没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。
|
||||
end_users.id(UUID)转为字符串后与 implicit_emotions_storage.end_user_id(String)对比。
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
from sqlalchemy import cast, String as SAString
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
tomorrow_start = today_start + timedelta(days=1)
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(EndUser.id)
|
||||
.where(
|
||||
EndUser.created_at >= today_start,
|
||||
EndUser.created_at < tomorrow_start,
|
||||
not_(
|
||||
exists(
|
||||
select(ImplicitEmotionsStorage.end_user_id).where(
|
||||
ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.order_by(EndUser.id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from (str(uid) for uid in batch)
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def delete_by_end_user_id(self, end_user_id: str) -> bool:
|
||||
"""删除用户的存储记录(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage:
|
||||
self.db.delete(storage)
|
||||
self.db.flush()
|
||||
logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}")
|
||||
return True
|
||||
return False
|
||||
72
api/app/repositories/mcp_market_config_repository.py
Normal file
72
api/app/repositories/mcp_market_config_repository.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.mcp_market_config_model import McpMarketConfig
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def create_mcp_market_config(db: Session, mcp_market_config: mcp_market_config_schema.McpMarketConfigCreate) -> McpMarketConfig:
|
||||
db_logger.debug(f"Create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = McpMarketConfig(**mcp_market_config.model_dump())
|
||||
db.add(db_mcp_market_config)
|
||||
db.commit()
|
||||
db_logger.info(f"McpMarketConfig record created successfully: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id})")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to create a mcp market config record: mcp_market_id={mcp_market_config.mcp_market_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID) -> McpMarketConfig | None:
|
||||
db_logger.debug(f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).first()
|
||||
if db_mcp_market_config:
|
||||
db_logger.debug(f"McpMarketConfig query successful: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarketConfig does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, tenant_id: uuid.UUID) -> McpMarketConfig | None:
|
||||
db_logger.debug(f"Query mcp market config based on mcp_market_id: {mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market_config = db.query(McpMarketConfig).filter(McpMarketConfig.mcp_market_id == mcp_market_id, McpMarketConfig.tenant_id == tenant_id).first()
|
||||
if db_mcp_market_config:
|
||||
db_logger.debug(f"McpMarketConfig query successful: (mcp_market_id: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarketConfig does not exist: mcp_market_id={mcp_market_id}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID):
|
||||
db_logger.debug(f"Delete McpMarketConfig record: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market config information for logging purposes
|
||||
result = db.query(McpMarketConfig).filter(McpMarketConfig.id == mcp_market_config_id).delete()
|
||||
db.commit()
|
||||
|
||||
if result > 0:
|
||||
db_logger.info(f"McpMarketConfig record deleted successfully: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
db_logger.warning(f"The mcp market config record does not exist, and cannot be deleted: id={mcp_market_config_id}")
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to delete mcp market config record: id={mcp_market_config_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
124
api/app/repositories/mcp_market_repository.py
Normal file
124
api/app/repositories/mcp_market_repository.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.mcp_market_model import McpMarket
|
||||
from app.schemas import mcp_market_schema
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# Obtain a dedicated logger for the database
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
def get_mcp_markets_paginated(
|
||||
db: Session,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
"""
|
||||
Paged query mcp market (with filtering and sorting)
|
||||
"""
|
||||
db_logger.debug(
|
||||
f"Query mcp market in pages: page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}, filters_count={len(filters)}")
|
||||
|
||||
try:
|
||||
query = db.query(McpMarket)
|
||||
|
||||
# Apply filter conditions
|
||||
for filter_cond in filters:
|
||||
query = query.filter(filter_cond)
|
||||
|
||||
# Calculate the total count (for pagination)
|
||||
total = query.count()
|
||||
db_logger.debug(f"Total number of mcp_market queries: {total}")
|
||||
|
||||
# sort
|
||||
if orderby:
|
||||
order_attr = getattr(McpMarket, orderby, None)
|
||||
if order_attr is not None:
|
||||
if desc:
|
||||
query = query.order_by(order_attr.desc())
|
||||
else:
|
||||
query = query.order_by(order_attr.asc())
|
||||
db_logger.debug(f"sort: {orderby}, desc={desc}")
|
||||
|
||||
# pagination
|
||||
items = query.offset((page - 1) * pagesize).limit(pagesize).all()
|
||||
db_logger.info(
|
||||
f"The mcp market paging query has been successful: total={total}, Number of current page={len(items)}")
|
||||
|
||||
return total, [mcp_market_schema.McpMarket.model_validate(item) for item in items]
|
||||
except Exception as e:
|
||||
db_logger.error(f"Querying mcp_market pagination failed: page={page}, pagesize={pagesize} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_market(db: Session, mcp_market: mcp_market_schema.McpMarketCreate) -> McpMarket:
|
||||
db_logger.debug(f"Create a mcp market record: name={mcp_market.name}")
|
||||
|
||||
try:
|
||||
db_mcp_market = McpMarket(**mcp_market.model_dump())
|
||||
db.add(db_mcp_market)
|
||||
db.commit()
|
||||
db_logger.info(f"McpMarket record created successfully: {mcp_market.name} (ID: {db_mcp_market.id})")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to create a mcp market record: title={mcp_market.name} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID) -> McpMarket | None:
|
||||
db_logger.debug(f"Query mcp market based on ID: mcp_market_id={mcp_market_id}")
|
||||
|
||||
try:
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
|
||||
if db_mcp_market:
|
||||
db_logger.debug(f"McpMarket query successful: {db_mcp_market.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.debug(f"McpMarket does not exist: mcp_market_id={mcp_market_id}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market based on the ID: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_name(db: Session, name: str) -> McpMarket | None:
|
||||
db_logger.debug(f"Query mcp market based on name: name={name}")
|
||||
|
||||
try:
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.name == name).first()
|
||||
if db_mcp_market:
|
||||
db_logger.debug(f"mcp market query successful: {name} (ID: {db_mcp_market.id})")
|
||||
else:
|
||||
db_logger.debug(f"mcp market does not exist: name={name}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query the mcp market based on the name: {name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID):
|
||||
db_logger.debug(f"Delete McpMarket record: mcp_market_id={mcp_market_id}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market information for logging purposes
|
||||
db_mcp_market = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).first()
|
||||
if db_mcp_market:
|
||||
name = db_mcp_market.name
|
||||
else:
|
||||
name = "unknown"
|
||||
|
||||
result = db.query(McpMarket).filter(McpMarket.id == mcp_market_id).delete()
|
||||
db.commit()
|
||||
|
||||
if result > 0:
|
||||
db_logger.info(f"McpMarket record deleted successfully: {name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
db_logger.warning(f"The mcp market record does not exist, and cannot be deleted: mcp_market_id={mcp_market_id}")
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to delete mcp market record: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise
|
||||
@@ -48,13 +48,17 @@ class ModelConfigRepository:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def get_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据名称获取模型配置"""
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}, tenant_id={tenant_id}")
|
||||
def get_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> Optional[ModelConfig]:
|
||||
"""根据名称和供应商获取模型配置"""
|
||||
db_logger.debug(f"根据名称查询模型配置: name={name}, provider={provider}, tenant_id={tenant_id}")
|
||||
|
||||
try:
|
||||
query = db.query(ModelConfig).filter(ModelConfig.name == name)
|
||||
|
||||
# 添加供应商过滤
|
||||
if provider:
|
||||
query = query.filter(ModelConfig.provider == provider)
|
||||
|
||||
# 添加租户过滤
|
||||
if tenant_id:
|
||||
query = query.filter(
|
||||
@@ -69,7 +73,7 @@ class ModelConfigRepository:
|
||||
db_logger.debug(f"模型配置查询成功: {model.name}")
|
||||
return model
|
||||
except Exception as e:
|
||||
db_logger.error(f"根据名称查询模型配置失败: name={name} - {str(e)}")
|
||||
db_logger.error(f"根据名称查询模型配置失败: name={name}, provider={provider} - {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -115,6 +115,7 @@ class WorkspaceRepository:
|
||||
self.db.query(Workspace)
|
||||
.join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
|
||||
.filter(WorkspaceMember.user_id == user_id)
|
||||
.filter(WorkspaceMember.is_active.is_(True))
|
||||
.filter(Workspace.is_active.is_(True))
|
||||
.order_by(Workspace.updated_at.desc())
|
||||
.all()
|
||||
|
||||
@@ -8,6 +8,8 @@ from .file_schema import File, FileCreate, FileUpdate
|
||||
from .tenant_schema import Tenant, TenantCreate, TenantUpdate
|
||||
from .chunk_schema import ChunkCreate, ChunkUpdate, ChunkRetrieve
|
||||
from .knowledgeshare_schema import KnowledgeShare, KnowledgeShareCreate
|
||||
from .mcp_market_schema import McpMarketCreate, McpMarketUpdate, McpMarket
|
||||
from .mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate, McpMarketConfig
|
||||
from .order_schema import CreateOrderRequest, OrderResponse, ExternalOrderResponse
|
||||
from .app_schema import (
|
||||
AppChatRequest,
|
||||
@@ -78,6 +80,12 @@ __all__ = [
|
||||
"ChunkRetrieve",
|
||||
"KnowledgeShare",
|
||||
"KnowledgeShareCreate",
|
||||
"McpMarketCreate",
|
||||
"McpMarketUpdate",
|
||||
"McpMarket",
|
||||
"McpMarketConfigCreate",
|
||||
"McpMarketConfigUpdate",
|
||||
"McpMarketConfig",
|
||||
"CreateOrderRequest",
|
||||
"OrderResponse",
|
||||
"ExternalOrderResponse",
|
||||
|
||||
@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
|
||||
31
api/app/schemas/mcp_market_config_schema.py
Normal file
31
api/app/schemas/mcp_market_config_schema.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class McpMarketConfigBase(BaseModel):
|
||||
mcp_market_id: uuid.UUID
|
||||
token: str | None = None
|
||||
status: int | None = None
|
||||
tenant_id: uuid.UUID | None = None
|
||||
created_by: uuid.UUID | None = None
|
||||
|
||||
|
||||
class McpMarketConfigCreate(McpMarketConfigBase):
|
||||
pass
|
||||
|
||||
|
||||
class McpMarketConfigUpdate(BaseModel):
|
||||
token: str | None = None
|
||||
status: int | None = None
|
||||
|
||||
|
||||
class McpMarketConfig(McpMarketConfigBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
37
api/app/schemas/mcp_market_schema.py
Normal file
37
api/app/schemas/mcp_market_schema.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class McpMarketBase(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
logo_url: str | None = None
|
||||
mcp_count: int
|
||||
url: str
|
||||
category: str
|
||||
created_by: uuid.UUID | None = None
|
||||
|
||||
|
||||
class McpMarketCreate(McpMarketBase):
|
||||
pass
|
||||
|
||||
|
||||
class McpMarketUpdate(BaseModel):
|
||||
name: str | None = Field(None)
|
||||
description: str | None = Field(None)
|
||||
logo_url: str | None = Field(None)
|
||||
mcp_count: int | None = Field(None)
|
||||
url: str | None = Field(None)
|
||||
category: str | None = Field(None)
|
||||
|
||||
|
||||
class McpMarket(McpMarketBase):
|
||||
id: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -25,9 +25,9 @@ class ModelConfigBase(BaseModel):
|
||||
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
"""用于在创建模型时内嵌创建API Key的Schema"""
|
||||
model_name: str = Field(..., description="模型实际名称", max_length=255)
|
||||
model_name: Optional[str] = Field(None, description="模型实际名称", max_length=255)
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
provider: Optional[str] = Field(None, description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
@@ -57,6 +57,8 @@ class ModelConfigUpdate(BaseModel):
|
||||
"""更新模型配置Schema"""
|
||||
name: Optional[str] = Field(None, description="模型显示名称", max_length=255)
|
||||
type: Optional[ModelType] = Field(None, description="模型类型")
|
||||
provider: Optional[str] = Field(None, description="供应商")
|
||||
logo: Optional[str] = Field(None, description="模型logo图片URL", max_length=255)
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
|
||||
@@ -27,4 +27,5 @@ class TokenRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
invite: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
@@ -36,6 +36,28 @@ class AdminChangePasswordRequest(BaseModel):
|
||||
new_password: Optional[str] = Field(None, min_length=6, description="新密码,至少6位。如果不提供则自动生成随机密码")
|
||||
|
||||
|
||||
class ChangeEmailRequest(BaseModel):
|
||||
"""修改邮箱请求"""
|
||||
password: str = Field(..., description="当前密码")
|
||||
new_email: EmailStr = Field(..., description="新邮箱地址")
|
||||
|
||||
|
||||
class SendEmailCodeRequest(BaseModel):
|
||||
"""发送邮箱验证码请求"""
|
||||
email: EmailStr = Field(..., description="邮箱地址")
|
||||
|
||||
|
||||
class VerifyEmailCodeRequest(BaseModel):
|
||||
"""验证邮箱验证码并修改邮箱请求"""
|
||||
new_email: EmailStr = Field(..., description="新邮箱地址")
|
||||
code: str = Field(..., min_length=6, max_length=6, description="验证码")
|
||||
|
||||
|
||||
class VerifyPasswordRequest(BaseModel):
|
||||
"""验证密码请求"""
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class ChangePasswordResponse(BaseModel):
|
||||
"""修改密码响应"""
|
||||
message: str
|
||||
|
||||
@@ -129,7 +129,8 @@ def register_user_with_invite(
|
||||
email: str,
|
||||
password: str,
|
||||
invite_token: str,
|
||||
workspace_id: str
|
||||
workspace_id: str,
|
||||
username: Optional[str] = None,
|
||||
) -> User:
|
||||
"""
|
||||
使用邀请码注册新用户并加入工作空间
|
||||
@@ -139,6 +140,7 @@ def register_user_with_invite(
|
||||
:param password: 用户密码
|
||||
:param invite_token: 邀请令牌
|
||||
:param workspace_id: 工作空间ID
|
||||
:param username: 用户名
|
||||
:return: 创建的用户对象
|
||||
"""
|
||||
from app.schemas.user_schema import UserCreate
|
||||
@@ -154,7 +156,7 @@ def register_user_with_invite(
|
||||
user_create = UserCreate(
|
||||
email=email,
|
||||
password=password,
|
||||
username=email.split('@')[0]
|
||||
username=email.split('@')[0] if not username else username
|
||||
)
|
||||
user = user_service.create_user(db=db, user=user_create)
|
||||
logger.info(f"用户创建成功: {user.email} (ID: {user.id})")
|
||||
|
||||
88
api/app/services/email_service.py
Normal file
88
api/app/services/email_service.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import smtplib
|
||||
import re
|
||||
import asyncio
|
||||
from email.mime.text import MIMEText
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.header import Header
|
||||
from email.utils import formataddr
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _send_email_sync(to_email: str, subject: str, html_content: str, text_content: str = None):
|
||||
"""同步发送邮件"""
|
||||
smtp_server = settings.SMTP_SERVER
|
||||
smtp_port = settings.SMTP_PORT
|
||||
smtp_user = settings.SMTP_USER
|
||||
smtp_password = settings.SMTP_PASSWORD
|
||||
|
||||
if not smtp_server or not smtp_user or not smtp_password:
|
||||
raise BusinessException("邮件服务未配置", code=BizCode.SERVICE_UNAVAILABLE)
|
||||
|
||||
msg = MIMEMultipart('alternative')
|
||||
msg['Subject'] = Header(subject, "utf-8")
|
||||
from_name = "MemoryBear系统"
|
||||
msg['From'] = formataddr((Header(from_name, 'utf-8').encode(), smtp_user))
|
||||
msg['To'] = Header(to_email, "utf-8")
|
||||
|
||||
if not text_content:
|
||||
text_content = html_content.replace('<br>', '\n').replace('<p>', '\n').replace('</p>', '\n')
|
||||
text_content = re.sub(r'<.*?>', '', text_content)
|
||||
text_part = MIMEText(text_content, 'plain', 'utf-8')
|
||||
msg.attach(text_part)
|
||||
|
||||
html_part = MIMEText(html_content, 'html', 'utf-8')
|
||||
msg.attach(html_part)
|
||||
|
||||
if smtp_port == 465:
|
||||
with smtplib.SMTP_SSL(smtp_server, smtp_port, timeout=10) as server:
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
else:
|
||||
with smtplib.SMTP(smtp_server, smtp_port, timeout=10) as server:
|
||||
server.starttls()
|
||||
server.login(smtp_user, smtp_password)
|
||||
server.send_message(msg)
|
||||
|
||||
|
||||
async def send_email(to_email: str, subject: str, html_content: str, text_content: str = None):
|
||||
"""异步发送邮件"""
|
||||
to_email = to_email.strip()
|
||||
if not to_email or not re.match(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$', to_email):
|
||||
err_msg = f"收件人邮箱格式无效: {to_email}"
|
||||
business_logger.error(err_msg)
|
||||
raise BusinessException(err_msg, code=BizCode.INVALID_PARAMETER)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
with ThreadPoolExecutor() as executor:
|
||||
await loop.run_in_executor(
|
||||
executor,
|
||||
_send_email_sync,
|
||||
to_email,
|
||||
subject,
|
||||
html_content,
|
||||
text_content
|
||||
)
|
||||
business_logger.info(f"邮件发送成功: {to_email}")
|
||||
except smtplib.SMTPAuthenticationError:
|
||||
err_msg = "SMTP认证失败,请检查SMTP账号/密码是否正确"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.UNAUTHORIZED)
|
||||
except smtplib.SMTPConnectError:
|
||||
err_msg = "SMTP服务器连接失败,请检查服务器地址/端口是否正确"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.SERVICE_UNAVAILABLE)
|
||||
except TimeoutError:
|
||||
err_msg = "邮件发送超时,请检查SMTP服务器配置"
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {err_msg}")
|
||||
raise BusinessException(err_msg, code=BizCode.BAD_REQUEST)
|
||||
except Exception as e:
|
||||
business_logger.error(f"邮件发送失败: {to_email} - {str(e)}")
|
||||
raise BusinessException(f"邮件发送失败: {str(e)}", code=BizCode.SERVICE_UNAVAILABLE)
|
||||
@@ -843,32 +843,33 @@ class EmotionAnalyticsService:
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""从 Redis 缓存获取个性化情绪建议
|
||||
"""从数据库获取个性化情绪建议
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 缓存的建议数据,如果不存在或已过期返回 None
|
||||
Dict: 存储的建议数据,如果不存在返回 None
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
|
||||
logger.info(f"尝试从数据库获取情绪建议: user={end_user_id}")
|
||||
|
||||
# 从 Redis 获取缓存
|
||||
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
|
||||
# 从数据库获取存储记录
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
storage = repo.get_by_end_user_id(end_user_id)
|
||||
|
||||
if cached_data is None:
|
||||
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
|
||||
if storage is None or storage.emotion_suggestions is None:
|
||||
logger.info(f"用户 {end_user_id} 的建议数据不存在")
|
||||
return None
|
||||
|
||||
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
|
||||
return cached_data
|
||||
logger.info(f"成功从数据库获取建议: user={end_user_id}")
|
||||
return storage.emotion_suggestions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"从数据库获取建议失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def save_suggestions_cache(
|
||||
@@ -876,36 +877,27 @@ class EmotionAnalyticsService:
|
||||
end_user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
db: Session,
|
||||
expires_hours: int = 24
|
||||
expires_hours: int = 24 # 参数保留以保持接口兼容性
|
||||
) -> None:
|
||||
"""保存建议到 Redis 缓存
|
||||
"""保存建议到数据库
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
suggestions_data: 建议数据
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
expires_hours: 过期时间(小时),默认24小时
|
||||
db: 数据库会话
|
||||
expires_hours: 保留参数(兼容性)
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
||||
logger.info(f"保存建议到数据库: user={end_user_id}")
|
||||
|
||||
# 计算过期时间(秒)
|
||||
expire_seconds = expires_hours * 3600
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
repo.update_emotion_suggestions(end_user_id, suggestions_data)
|
||||
db.commit()
|
||||
|
||||
# 保存到 Redis
|
||||
success = await EmotionMemoryCache.set_emotion_suggestions(
|
||||
user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
expire=expire_seconds
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"建议缓存保存成功: user={end_user_id}")
|
||||
else:
|
||||
logger.warning(f"建议缓存保存失败: user={end_user_id}")
|
||||
logger.info(f"建议保存成功: user={end_user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
|
||||
# 不抛出异常,缓存失败不应影响主流程
|
||||
db.rollback()
|
||||
logger.error(f"保存建议失败: {str(e)}", exc_info=True)
|
||||
@@ -26,7 +26,7 @@ logger = get_business_logger()
|
||||
|
||||
def generate_file_key(
|
||||
tenant_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None,
|
||||
file_id: uuid.UUID,
|
||||
file_ext: str,
|
||||
) -> str:
|
||||
@@ -56,8 +56,9 @@ def generate_file_key(
|
||||
# Ensure file_ext starts with a dot
|
||||
if file_ext and not file_ext.startswith('.'):
|
||||
file_ext = f'.{file_ext}'
|
||||
|
||||
return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}"
|
||||
if workspace_id:
|
||||
return f"{tenant_id}/{workspace_id}/{file_id}{file_ext}"
|
||||
return f"{tenant_id}/{file_id}{file_ext}"
|
||||
|
||||
|
||||
class FileStorageService:
|
||||
@@ -96,7 +97,7 @@ class FileStorageService:
|
||||
async def upload_file(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID | None,
|
||||
file_id: uuid.UUID,
|
||||
file_ext: str,
|
||||
content: bytes,
|
||||
|
||||
@@ -422,32 +422,33 @@ class ImplicitMemoryService:
|
||||
end_user_id: str,
|
||||
db: Session
|
||||
) -> Optional[dict]:
|
||||
"""从 Redis 缓存获取完整用户画像
|
||||
"""从数据库获取完整用户画像
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 缓存的画像数据,如果不存在或已过期返回 None
|
||||
Dict: 存储的画像数据,如果不存在返回 None
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.implicit_memory import ImplicitMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}")
|
||||
logger.info(f"尝试从数据库获取用户画像: user={end_user_id}")
|
||||
|
||||
# 从 Redis 获取缓存
|
||||
cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id)
|
||||
# 从数据库获取存储记录
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
storage = repo.get_by_end_user_id(end_user_id)
|
||||
|
||||
if cached_data is None:
|
||||
logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
if storage is None or storage.implicit_profile is None:
|
||||
logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return None
|
||||
|
||||
logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}")
|
||||
return cached_data
|
||||
logger.info(f"成功从数据库获取用户画像: user={end_user_id}")
|
||||
return storage.implicit_profile
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 缓存获取用户画像失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"从数据库获取用户画像失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def save_profile_cache(
|
||||
@@ -455,36 +456,27 @@ class ImplicitMemoryService:
|
||||
end_user_id: str,
|
||||
profile_data: dict,
|
||||
db: Session,
|
||||
expires_hours: int = 168 # 默认7天
|
||||
expires_hours: int = 168 # 参数保留以保持接口兼容性
|
||||
) -> None:
|
||||
"""保存用户画像到 Redis 缓存
|
||||
"""保存用户画像到数据库
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
profile_data: 画像数据
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
expires_hours: 过期时间(小时),默认168小时(7天)
|
||||
db: 数据库会话
|
||||
expires_hours: 保留参数(兼容性)
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.implicit_memory import ImplicitMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
||||
logger.info(f"保存用户画像到数据库: user={end_user_id}")
|
||||
|
||||
# 计算过期时间(秒)
|
||||
expire_seconds = expires_hours * 3600
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
repo.update_implicit_profile(end_user_id, profile_data)
|
||||
db.commit()
|
||||
|
||||
# 保存到 Redis
|
||||
success = await ImplicitMemoryCache.set_user_profile(
|
||||
user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
expire=expire_seconds
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户画像缓存保存成功: user={end_user_id}")
|
||||
else:
|
||||
logger.warning(f"用户画像缓存保存失败: user={end_user_id}")
|
||||
logger.info(f"用户画像保存成功: user={end_user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True)
|
||||
# 不抛出异常,缓存失败不应影响主流程
|
||||
db.rollback()
|
||||
logger.error(f"保存用户画像失败: {str(e)}", exc_info=True)
|
||||
|
||||
83
api/app/services/mcp_market_config_service.py
Normal file
83
api/app/services/mcp_market_config_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.mcp_market_config_model import McpMarketConfig
|
||||
from app.schemas.mcp_market_config_schema import McpMarketConfigCreate, McpMarketConfigUpdate
|
||||
from app.repositories import mcp_market_config_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def create_mcp_market_config(
|
||||
db: Session, mcp_market_config: McpMarketConfigCreate, current_user: User
|
||||
) -> McpMarketConfig:
|
||||
business_logger.info(f"Create a mcp market config base: {mcp_market_config.mcp_market_id}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcp_market_config.tenant_id = current_user.tenant_id
|
||||
mcp_market_config.created_by = current_user.id
|
||||
business_logger.debug(f"Start creating the mcp market config on mcp_market_id: {mcp_market_config.mcp_market_id}")
|
||||
db_mcp_market_config = mcp_market_config_repository.create_mcp_market_config(
|
||||
db=db, mcp_market_config=mcp_market_config
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market config has been successfully created: {mcp_market_config.mcp_market_id} (ID: {db_mcp_market_config.id}), creator: {current_user.username}")
|
||||
return db_mcp_market_config
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a mcp marke config: {mcp_market_config.mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> McpMarketConfig | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market config based on ID: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.info(f"mcp market config query successful: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market config does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
return mcpMarketConfig
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market config based on the ID: {mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_config_by_mcp_market_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarketConfig | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market config based on mcp_market_id: {mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_mcp_market_id(db=db, mcp_market_id=mcp_market_id, tenant_id=current_user.tenant_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.info(f"mcp market config query successful: (mcp_market_id: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market config does not exist: mcp_market_id={mcp_market_id}")
|
||||
return mcpMarketConfig
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market config based on the mcp_market_id: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_config_by_id(db: Session, mcp_market_config_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete mcp market config: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market config information for logging purposes
|
||||
mcpMarketConfig = mcp_market_config_repository.get_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
if mcpMarketConfig:
|
||||
business_logger.debug(f"Execute mcp market config deletion: (ID: {mcp_market_config_id})")
|
||||
else:
|
||||
business_logger.warning(f"The mcp market config to be deleted does not exist: mcp_market_config_id={mcp_market_config_id}")
|
||||
|
||||
mcp_market_config_repository.delete_mcp_market_config_by_id(db=db, mcp_market_config_id=mcp_market_config_id)
|
||||
business_logger.info(
|
||||
f"mcp market config record deleted successfully: mcp_market_config_id={mcp_market_config_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
|
||||
raise
|
||||
109
api/app/services/mcp_market_service.py
Normal file
109
api/app/services/mcp_market_service.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.user_model import User
|
||||
from app.models.mcp_market_model import McpMarket
|
||||
from app.schemas.mcp_market_schema import McpMarketCreate, McpMarketUpdate
|
||||
from app.repositories import mcp_market_repository
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
# Obtain a dedicated logger for business logic
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def get_mcp_markets_paginated(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
filters: list,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
orderby: str = None,
|
||||
desc: bool = False
|
||||
) -> tuple[int, list]:
|
||||
business_logger.debug(
|
||||
f"Query mcp market in pages: username={current_user.username}, page={page}, pagesize={pagesize}, orderby={orderby}, desc={desc}")
|
||||
|
||||
try:
|
||||
total, items = mcp_market_repository.get_mcp_markets_paginated(
|
||||
db=db,
|
||||
filters=filters,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
orderby=orderby,
|
||||
desc=desc
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market paging query has been successful: username={current_user.username}, total={total}, Number of current page={len(items)}")
|
||||
return total, items
|
||||
except Exception as e:
|
||||
business_logger.error(f"Querying mcp market pagination failed: username={current_user.username} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_mcp_market(
|
||||
db: Session, mcp_market: McpMarketCreate, current_user: User
|
||||
) -> McpMarket:
|
||||
business_logger.info(f"Create a mcp market base: {mcp_market.name}, creator: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcp_market.created_by = current_user.id
|
||||
business_logger.debug(f"Start creating the mcp market: {mcp_market.name}")
|
||||
db_mcp_market = mcp_market_repository.create_mcp_market(
|
||||
db=db, mcp_market=mcp_market
|
||||
)
|
||||
business_logger.info(
|
||||
f"The mcp market has been successfully created: {mcp_market.name} (ID: {db_mcp_market.id}), creator: {current_user.username}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to create a mcp market: {mcp_market.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> McpMarket | None:
|
||||
business_logger.debug(
|
||||
f"Query mcp market based on ID: mcp_market_id={mcp_market_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
if mcpMarket:
|
||||
business_logger.info(f"mcp market query successful: {mcpMarket.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market does not exist: mcp_market_id={mcp_market_id}")
|
||||
return mcpMarket
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"Failed to query the mcp market based on the ID: {mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_mcp_market_by_name(db: Session, name: str, current_user: User) -> McpMarket | None:
|
||||
business_logger.debug(f"Query mcp market based on name: name={name}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
db_mcp_market = mcp_market_repository.get_mcp_market_by_name(db=db, name=name)
|
||||
if db_mcp_market:
|
||||
business_logger.info(f"mcp market query successful: {name} (ID: {db_mcp_market.id})")
|
||||
else:
|
||||
business_logger.warning(f"mcp market does not exist: name={name}")
|
||||
return db_mcp_market
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to query the mcp market based on the name: name={name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_mcp_market_by_id(db: Session, mcp_market_id: uuid.UUID, current_user: User) -> None:
|
||||
business_logger.info(f"Delete mcp market: mcp_market_id={mcp_market_id}, operator: {current_user.username}")
|
||||
|
||||
try:
|
||||
# First, query the mcp market information for logging purposes
|
||||
mcpMarket = mcp_market_repository.get_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
if mcpMarket:
|
||||
business_logger.debug(f"Execute mcp market deletion: {mcpMarket.name} (ID: {mcp_market_id})")
|
||||
else:
|
||||
business_logger.warning(f"The mcp market to be deleted does not exist: mcp_market_id={mcp_market_id}")
|
||||
|
||||
mcp_market_repository.delete_mcp_market_by_id(db=db, mcp_market_id=mcp_market_id)
|
||||
business_logger.info(
|
||||
f"mcp market record deleted successfully: mcp_market_id={mcp_market_id}, operator: {current_user.username}")
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to delete mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
|
||||
raise
|
||||
@@ -816,11 +816,10 @@ class MemoryAgentService:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||||
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
|
||||
3. total: 所有类型的总和
|
||||
2. total: 所有类型的总和
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户组ID(可选,未提供时 memory 统计为 0)
|
||||
- end_user_id: 用户组ID(可选,保留参数以保持接口兼容性)
|
||||
- only_active: 是否仅统计有效记录
|
||||
- current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0)
|
||||
- db: 数据库会话
|
||||
@@ -831,7 +830,6 @@ class MemoryAgentService:
|
||||
"Web": count,
|
||||
"Third-party": count,
|
||||
"Folder": count,
|
||||
"memory": chunk_count,
|
||||
"total": sum_of_all
|
||||
}
|
||||
"""
|
||||
@@ -878,51 +876,8 @@ class MemoryAgentService:
|
||||
logger.error(f"知识库类型统计失败: {e}")
|
||||
raise Exception(f"知识库类型统计失败: {e}")
|
||||
|
||||
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
|
||||
try:
|
||||
if current_workspace_id:
|
||||
# 获取当前空间下的所有宿主
|
||||
from app.repositories import app_repository, end_user_repository
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
# 查询应用并转换为 Pydantic 模型
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||
apps = [AppSchema.model_validate(h) for h in apps_orm]
|
||||
app_ids = [app.id for app in apps]
|
||||
|
||||
# 获取所有宿主
|
||||
end_users = []
|
||||
for app_id in app_ids:
|
||||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||||
end_users.extend(h for h in end_user_orm_list)
|
||||
|
||||
# 统计所有宿主的 Chunk 总数
|
||||
total_chunks = 0
|
||||
for end_user in end_users:
|
||||
end_user_id_str = str(end_user.id)
|
||||
memory_query = """
|
||||
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
|
||||
"""
|
||||
neo4j_result = await _neo4j_connector.execute_query(
|
||||
memory_query,
|
||||
end_user_id=end_user_id_str,
|
||||
)
|
||||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||||
total_chunks += chunk_count
|
||||
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
|
||||
|
||||
result["memory"] = total_chunks
|
||||
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
|
||||
else:
|
||||
# 没有 workspace_id 时,返回 0
|
||||
result["memory"] = 0
|
||||
logger.info("未提供 workspace_id,memory 统计为 0")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
|
||||
# 如果 Neo4j 查询失败,memory 设为 0
|
||||
result["memory"] = 0
|
||||
# 2. 统计 Neo4j 中的 memory 总量已移除
|
||||
# memory 字段不再返回
|
||||
|
||||
# 3. 计算知识库类型总和(不包括 memory)
|
||||
result["total"] = (
|
||||
|
||||
@@ -6,7 +6,7 @@ import math
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy
|
||||
from app.models.models_model import ModelConfig, ModelApiKey, ModelType, LoadBalanceStrategy, ModelProvider
|
||||
from app.repositories.model_repository import ModelConfigRepository, ModelApiKeyRepository, ModelBaseRepository
|
||||
from app.schemas import model_schema
|
||||
from app.schemas.model_schema import (
|
||||
@@ -69,9 +69,9 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
return model
|
||||
@@ -244,7 +244,7 @@ class ModelConfigService:
|
||||
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建模型配置"""
|
||||
# 检查名称是否已存在(同租户内)
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=model_data.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证配置
|
||||
@@ -253,8 +253,8 @@ class ModelConfigService:
|
||||
for api_key_data in api_key_data_list:
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name,
|
||||
provider=api_key_data.provider,
|
||||
model_name=model_data.name,
|
||||
provider=model_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
@@ -277,6 +277,8 @@ class ModelConfigService:
|
||||
|
||||
if api_key_datas:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_data.model_name = model_data.name
|
||||
api_key_data.provider = model_data.provider
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
@@ -295,7 +297,7 @@ class ModelConfigService:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -306,7 +308,7 @@ class ModelConfigService:
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
@@ -341,7 +343,7 @@ class ModelConfigService:
|
||||
"type": model_data.type,
|
||||
"logo": model_data.logo,
|
||||
"description": model_data.description,
|
||||
"provider": "composite",
|
||||
"provider": ModelProvider.COMPOSITE,
|
||||
"config": model_data.config,
|
||||
"is_active": model_data.is_active,
|
||||
"is_public": model_data.is_public,
|
||||
@@ -369,6 +371,10 @@ class ModelConfigService:
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
@@ -471,11 +477,14 @@ class ModelApiKeyService:
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
# 检查是否存在API Key(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
).filter(
|
||||
ModelApiKey.api_key == data.api_key,
|
||||
ModelApiKey.provider == data.provider,
|
||||
ModelApiKey.model_name == model_name
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
@@ -542,11 +551,14 @@ class ModelApiKeyService:
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 检查API Key是否已存在(包括软删除)
|
||||
existing_key = db.query(ModelApiKey).filter(
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
).filter(
|
||||
ModelApiKey.api_key == api_key_data.api_key,
|
||||
ModelApiKey.provider == api_key_data.provider,
|
||||
ModelApiKey.model_name == api_key_data.model_name
|
||||
ModelApiKey.model_name == api_key_data.model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import datetime
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
|
||||
from pydantic import EmailStr
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
|
||||
from app.aioRedis import aio_redis_set, aio_redis_get, aio_redis_delete
|
||||
from app.models.user_model import User
|
||||
from app.repositories import user_repository
|
||||
from app.schemas.user_schema import UserCreate
|
||||
from app.schemas.tenant_schema import TenantCreate
|
||||
from app.services.email_service import send_email
|
||||
from app.services.tenant_service import TenantService
|
||||
from app.services.session_service import SessionService
|
||||
from app.core.security import get_password_hash, verify_password
|
||||
@@ -563,3 +568,175 @@ def generate_random_password(length: int = 12) -> str:
|
||||
secrets.SystemRandom().shuffle(password)
|
||||
|
||||
return ''.join(password)
|
||||
|
||||
|
||||
def generate_email_code() -> str:
|
||||
"""生成6位数字验证码"""
|
||||
return ''.join([str(secrets.randbelow(10)) for _ in range(6)])
|
||||
|
||||
|
||||
async def send_email_code_method(db: Session, email: EmailStr, user_id: uuid.UUID):
|
||||
"""发送邮箱验证码"""
|
||||
business_logger.info(f"发送邮箱验证码: email={email}")
|
||||
|
||||
# 检查发送间隔
|
||||
rate_limit_key = f"email_code_rate:{user_id}"
|
||||
last_send = await aio_redis_get(rate_limit_key)
|
||||
|
||||
if last_send:
|
||||
raise BusinessException("请稍后再试,验证码发送间隔为1分钟", code=BizCode.RATE_LIMITED)
|
||||
|
||||
# 检查新邮箱是否已被使用
|
||||
existing_user = user_repository.get_user_by_email(db=db, email=email)
|
||||
if existing_user and existing_user.id != user_id:
|
||||
raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
if existing_user and existing_user.id == user_id:
|
||||
raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.DUPLICATE_NAME)
|
||||
|
||||
# 生成验证码
|
||||
code = generate_email_code()
|
||||
|
||||
# 存储到 Redis,5分钟过期
|
||||
cache_key = f"email_code:{user_id}:{email}"
|
||||
await aio_redis_set(cache_key, json.dumps(code), expire=300)
|
||||
|
||||
# 发送邮件
|
||||
await send_email(
|
||||
email,
|
||||
"邮箱验证码",
|
||||
f'<p>您的验证码是:<strong>{code}</strong></p><p>验证码在5分钟内有效。</p>'
|
||||
)
|
||||
|
||||
# 设置发送间隔限制,60秒
|
||||
await aio_redis_set(rate_limit_key, "1", expire=60)
|
||||
|
||||
business_logger.info(f"邮箱验证码已发送: {email}")
|
||||
|
||||
|
||||
async def verify_and_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, code: str) -> User:
|
||||
"""验证验证码并修改邮箱"""
|
||||
business_logger.info(f"验证并修改邮箱: user_id={user_id}, new_email={new_email}")
|
||||
|
||||
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
if not db_user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 验证验证码
|
||||
cache_key = f"email_code:{user_id}:{new_email}"
|
||||
cached_code = await aio_redis_get(cache_key)
|
||||
|
||||
if not cached_code:
|
||||
raise BusinessException("验证码已过期", code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
if json.loads(cached_code) != code:
|
||||
raise BusinessException("验证码错误", code=BizCode.VALIDATION_FAILED)
|
||||
|
||||
# 修改邮箱
|
||||
db_user.email = new_email
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
|
||||
# 删除验证码
|
||||
await aio_redis_delete(cache_key)
|
||||
|
||||
# 使所有旧 tokens 失效
|
||||
# await SessionService.invalidate_all_user_tokens(str(user_id))
|
||||
|
||||
business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
return db_user
|
||||
|
||||
|
||||
# def generate_email_token(user_id: str, old_email: str, new_email: str) -> str:
|
||||
# """生成邮箱修改token"""
|
||||
# payload = {
|
||||
# "user_id": user_id,
|
||||
# "old_email": old_email,
|
||||
# "new_email": new_email,
|
||||
# "exp": datetime.datetime.now(datetime.timezone.utc) + timedelta(hours=24)
|
||||
# }
|
||||
# return jwt.encode(payload, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
|
||||
#
|
||||
#
|
||||
# def verify_email_token(token: str) -> dict:
|
||||
# """验证邮箱修改token"""
|
||||
# try:
|
||||
# payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
# return payload
|
||||
# except jwt.ExpiredSignatureError:
|
||||
# raise BusinessException("链接已过期", code=BizCode.VALIDATION_FAILED)
|
||||
# except jwt.InvalidTokenError:
|
||||
# raise BusinessException("无效的链接", code=BizCode.VALIDATION_FAILED)
|
||||
#
|
||||
#
|
||||
# async def request_change_email(db: Session, user_id: uuid.UUID, new_email: EmailStr, current_user: User):
|
||||
# """请求修改邮箱,发送验证邮件"""
|
||||
# business_logger.info(f"用户请求修改邮箱: user_id={user_id}, new_email={new_email}")
|
||||
#
|
||||
# if current_user.id != user_id:
|
||||
# raise PermissionDeniedException("只能修改自己的邮箱")
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# if db_user.email == new_email:
|
||||
# raise BusinessException("新邮箱与当前邮箱相同", code=BizCode.VALIDATION_FAILED)
|
||||
#
|
||||
# existing_user = user_repository.get_user_by_email(db=db, email=new_email)
|
||||
# if existing_user and existing_user.id != user_id:
|
||||
# raise BusinessException("邮箱已被使用", code=BizCode.DUPLICATE_NAME)
|
||||
#
|
||||
# token = generate_email_token(str(user_id), db_user.email, new_email)
|
||||
#
|
||||
# # 发送确认邮件到旧邮箱
|
||||
# old_email_link = f"{settings.BASE_URL}/api/users/email/confirm-email-change?token={token}"
|
||||
# await send_email(
|
||||
# db_user.email,
|
||||
# "确认修改邮箱",
|
||||
# f'<p>请点击以下链接确认修改邮箱:</p><a href="{old_email_link}">确认修改</a>'
|
||||
# )
|
||||
#
|
||||
# business_logger.info(f"邮箱修改确认邮件已发送到旧邮箱: {db_user.email}")
|
||||
#
|
||||
#
|
||||
# async def confirm_email_change(db: Session, token: str):
|
||||
# """确认修改邮箱(旧邮箱确认)"""
|
||||
# payload = verify_email_token(token)
|
||||
# user_id = uuid.UUID(payload["user_id"])
|
||||
# new_email = payload["new_email"]
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# # 发送激活邮件到新邮箱
|
||||
# activate_link = f"{settings.BASE_URL}/api/users/email/activate-new-email?token={token}"
|
||||
# await send_email(
|
||||
# new_email,
|
||||
# "激活新邮箱",
|
||||
# f'<p>请点击以下链接激活新邮箱:</p><a href="{activate_link}">激活邮箱</a>'
|
||||
# )
|
||||
#
|
||||
# business_logger.info(f"新邮箱激活邮件已发送: {new_email}")
|
||||
#
|
||||
#
|
||||
# async def activate_new_email(db: Session, token: str) -> User:
|
||||
# """激活新邮箱"""
|
||||
# payload = verify_email_token(token)
|
||||
# user_id = uuid.UUID(payload["user_id"])
|
||||
# new_email = payload["new_email"]
|
||||
#
|
||||
# db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
|
||||
# if not db_user:
|
||||
# raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
#
|
||||
# db_user.email = new_email
|
||||
# db.commit()
|
||||
# db.refresh(db_user)
|
||||
#
|
||||
# # 使所有旧 tokens 失效
|
||||
# await SessionService.invalidate_all_user_tokens(str(user_id))
|
||||
#
|
||||
# business_logger.info(f"用户邮箱修改成功: {db_user.username}, new_email={new_email}")
|
||||
# return db_user
|
||||
|
||||
@@ -588,7 +588,7 @@ class WorkflowService:
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
@@ -70,10 +70,10 @@ def delete_workspace_member(
|
||||
_check_workspace_admin_permission(db, workspace_id, user)
|
||||
workspace_member = workspace_repository.get_member_by_id(db=db, member_id=member_id)
|
||||
if not workspace_member:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
if workspace_member.workspace_id != workspace_id:
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_MEMBER_NOT_FOUND)
|
||||
raise BusinessException(f"工作空间成员 {member_id} 不存在于工作空间 {workspace_id}", BizCode.WORKSPACE_NOT_FOUND)
|
||||
|
||||
try:
|
||||
workspace_member.is_active = False
|
||||
|
||||
524
api/app/tasks.py
524
api/app/tasks.py
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -13,7 +14,6 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
import requests
|
||||
import trio
|
||||
|
||||
# Import a unified Celery instance
|
||||
from app.celery_app import celery_app
|
||||
@@ -66,6 +66,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
"""
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
db_knowledge = None
|
||||
@@ -292,6 +296,10 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
"""
|
||||
build knowledge graph
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_documents = None
|
||||
db_knowledge = None
|
||||
@@ -362,7 +370,7 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n")
|
||||
return result
|
||||
|
||||
try:
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
@@ -377,8 +385,15 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
future.result() # Blocks until the task completes
|
||||
except Exception as e:
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n")
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
print(f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)")
|
||||
|
||||
result = f"build knowledge graph '{db_knowledge.name}' processed successfully."
|
||||
@@ -389,7 +404,8 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
result = f"build knowledge grap '{db_knowledge.name}' failed."
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.rag.tasks.sync_knowledge_for_kb")
|
||||
@@ -1288,6 +1304,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
ignore_result=False,
|
||||
max_retries=3,
|
||||
acks_late=True,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:遍历所有工作空间,统计并写入记忆增量
|
||||
|
||||
此任务会:
|
||||
1. 查询所有活跃的工作空间
|
||||
2. 对每个工作空间统计记忆总量
|
||||
3. 将统计结果写入 memory_increments 表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
workspaces = db.query(Workspace).filter(
|
||||
Workspace.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
apps = db.query(App).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not apps:
|
||||
# 如果没有app,总量为0
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=0
|
||||
)
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
total_num = 0
|
||||
end_user_details = []
|
||||
|
||||
for (end_user_id,) in end_users:
|
||||
try:
|
||||
# 调用 search_all 接口查询该宿主的总量
|
||||
result = await search_all(str(end_user_id))
|
||||
user_total = result.get("total", 0)
|
||||
total_num += user_total
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": user_total
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 4. 写入数据库
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=total_num
|
||||
)
|
||||
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
})
|
||||
|
||||
total_memory = sum(r.get("total_num", 0) for r in all_workspace_results)
|
||||
success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}",
|
||||
"workspace_count": len(workspaces),
|
||||
"success_count": success_count,
|
||||
"total_memory": total_memory,
|
||||
"workspace_results": all_workspace_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -1908,4 +2121,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 隐性记忆和情绪数据更新定时任务
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.update_implicit_emotions_storage",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=7200, # 2小时硬超时
|
||||
soft_time_limit=6900, # 1小时55分钟软超时
|
||||
)
|
||||
def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"""定时任务:更新所有用户的隐性记忆画像和情绪建议数据
|
||||
|
||||
遍历数据库中所有已存在数据的用户,为每个用户重新生成隐性记忆画像和情绪建议。
|
||||
实现错误隔离,单个用户失败不影响其他用户的处理。
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典,包括:
|
||||
- status: 任务状态 (SUCCESS/FAILURE)
|
||||
- message: 执行消息
|
||||
- total_users: 总用户数
|
||||
- successful_implicit: 成功更新隐性记忆的用户数
|
||||
- successful_emotion: 成功更新情绪建议的用户数
|
||||
- failed: 失败的用户数
|
||||
- user_results: 每个用户的详细结果
|
||||
- elapsed_time: 执行耗时(秒)
|
||||
- task_id: 任务ID
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from sqlalchemy import select, func
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
|
||||
total_users = 0
|
||||
successful_implicit = 0
|
||||
successful_emotion = 0
|
||||
failed = 0
|
||||
user_results = []
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有已存储数据的用户ID(分批次处理)
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
# 先统计总数用于日志
|
||||
from sqlalchemy import func
|
||||
total_users = db.execute(
|
||||
select(func.count()).select_from(ImplicitEmotionsStorage)
|
||||
).scalar() or 0
|
||||
logger.info(f"找到 {total_users} 个需要更新的用户")
|
||||
|
||||
# 遍历每个用户并更新数据(分批次,避免一次性加载所有ID)
|
||||
for end_user_id in repo.get_all_user_ids(batch_size=100):
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# 更新隐性记忆画像
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功更新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆更新失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"用户 {end_user_id} {error_msg}")
|
||||
|
||||
# 更新情绪建议
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功更新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议更新失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"用户 {end_user_id} {error_msg}")
|
||||
|
||||
# 统计结果
|
||||
if implicit_success:
|
||||
successful_implicit += 1
|
||||
if emotion_success:
|
||||
successful_emotion += 1
|
||||
if not implicit_success and not emotion_success:
|
||||
failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
|
||||
# 记录用户处理结果
|
||||
user_result = {
|
||||
"end_user_id": end_user_id,
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
"elapsed_time": user_elapsed
|
||||
}
|
||||
user_results.append(user_result)
|
||||
|
||||
logger.info(
|
||||
f"用户 {end_user_id} 处理完成: "
|
||||
f"隐性记忆={'成功' if implicit_success else '失败'}, "
|
||||
f"情绪建议={'成功' if emotion_success else '失败'}, "
|
||||
f"耗时={user_elapsed:.2f}秒"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 单个用户失败不影响其他用户(错误隔离)
|
||||
failed += 1
|
||||
user_elapsed = time.time() - user_start_time
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
"elapsed_time": user_elapsed
|
||||
}
|
||||
user_results.append(error_info)
|
||||
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
|
||||
new_users_initialized = 0
|
||||
new_users_failed = 0
|
||||
logger.info("开始处理当天新增的增量用户初始化")
|
||||
|
||||
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
||||
logger.info(f"开始初始化新用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
|
||||
try:
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
|
||||
if implicit_success or emotion_success:
|
||||
new_users_initialized += 1
|
||||
else:
|
||||
new_users_failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
"elapsed_time": user_elapsed
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
new_users_failed += 1
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
"elapsed_time": user_elapsed
|
||||
})
|
||||
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
logger.info(
|
||||
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
|
||||
)
|
||||
# ---- 增量用户处理结束 ----
|
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||
f"存量用户总数={total_users}, "
|
||||
f"隐性记忆成功={successful_implicit}, "
|
||||
f"情绪建议成功={successful_emotion}, "
|
||||
f"存量失败={failed}, "
|
||||
f"增量初始化成功={new_users_initialized}, "
|
||||
f"增量初始化失败={new_users_failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": (
|
||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": new_users_initialized,
|
||||
"new_users_failed": new_users_failed,
|
||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": 0,
|
||||
"new_users_failed": 0,
|
||||
"user_results": user_results[:50]
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
@@ -1,4 +1,68 @@
|
||||
{
|
||||
"v0.2.5": {
|
||||
"introduction": {
|
||||
"codeName": "行云",
|
||||
"releaseDate": "2026-2-26",
|
||||
"upgradePosition": "🐻 精炼根基,优化核心用户体验与系统稳定性",
|
||||
"coreUpgrades": [
|
||||
"1. 用户体验与国际化 🎨<br>* 语言参数修复:语言偏好现正确保留<br>* 邮箱修改支持:用户可直接在用户管理系统中修改邮箱地址",
|
||||
"2. 工作流可视化增强 💬<br>* 循环与迭代节点输出展示:实时显示执行进度和中间输出,便于调试复杂迭代过程<br>* 变量支持回车选择:支持回车键确认变量选择,简化工作流配置流程",
|
||||
"3. 优化模型管理 ⚙️<br>* 模型广场移除自定义模型,优化模型使用体验",
|
||||
"4. 稳健性与缺陷修复 🔧<br>* 知识图谱构建修复:解决知识图谱构建流程稳定性问题,确保更可靠的实体提取和关系映射",
|
||||
"<br>",
|
||||
"版本 0.2.5 通过解决国际化边界情况和改进工作流透明度,构建更具生产就绪性的平台。工作流可视化改进为更复杂的调试和监控能力奠定基础。未来将继续深化企业就绪性,扩展用户管理功能、优化知识图谱智能和增强工作流编排能力,在可观测性、性能优化和无缝集成模式方面持续改进。",
|
||||
"智慧致远 🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "Flowing Clouds",
|
||||
"releaseDate": "2026-2-26",
|
||||
"upgradePosition": "🐻 Refined foundations with enhanced user experience and system stability",
|
||||
"coreUpgrades": [
|
||||
"1. User Experience & Internationalization 🎨<br>* Language parameter fix: language preferences are now correctly retained<br>* Email Update Support: Users can now modify email addresses directly in user management system",
|
||||
"2. Workflow Visualization Enhancements 💬<br>* Loop & Iteration Node Output Display: Real-time display of execution progress and intermediate outputs for easier debugging<br>* Variable Selection with Enter Key: Enabled Enter key confirmation for streamlined variable assignment",
|
||||
"3. Optimized Model Management ⚙️<br>* Custom models have been removed from the Model marketplace to optimize the model usage experience",
|
||||
"4. Robustness & Bug Fixes 🔧<br>* Knowledge Graph Construction Fix: Addressed stability issues in knowledge graph pipeline for more reliable entity extraction and relationship mapping",
|
||||
"<br>",
|
||||
"Version 0.2.5 matures MemoryBear's operational foundations by addressing internationalization edge cases and improving workflow transparency. The workflow visualization improvements lay groundwork for sophisticated debugging and monitoring capabilities. Looking forward, we will deepen enterprise readiness by expanding user management features, refining knowledge graph intelligence, and enhancing workflow orchestration with continued improvements in observability, performance optimization, and seamless integration patterns.",
|
||||
"Intelligent Resilience 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.4": {
|
||||
"introduction": {
|
||||
"codeName": "智远",
|
||||
"releaseDate": "2026-2-11",
|
||||
"upgradePosition": "🐻 生产级稳健性升级版本,智慧致远,从容应对复杂场景",
|
||||
"coreUpgrades": [
|
||||
"1. Skills 技能框架 🛠️<br>* Skills 支持:引入全新的Skills技能系统,支持可扩展的能力模块,可在Agent和工作流中动态加载与编排",
|
||||
"2. 多模态与交互 💬<br>* 文件多模态支持:全面支持消息输入、LLM处理和输出渲染中的多模态文件处理,实现更丰富的媒体感知对话<br>* 语音交互:语音交互功能正在积极开发中,为免提对话体验奠定基础(开发中)",
|
||||
"3. 知识库集成 📚<br>* 飞书知识库:无缝对接飞书文档库,支持企业知识检索<br>* 语雀知识库:原生连接语雀文档平台,扩展对国内企业工具生态的覆盖<br>* Web站点知识库:通用Web站点抓取与索引,支持从公开网页内容构建知识库<br>* 视觉模型选择优化:知识库视觉模型配置现已支持LLM和Chat两种模型类型,移除了此前仅限Chat类型的限制",
|
||||
"4. 记忆智能 🧠<br>* 本体工程(二期):基于本体工程的高级记忆场景分类与萃取,实现结构化、领域感知的记忆组织,提升分类准确性<br>* 默认模型配置:情绪分析、反思和记忆萃取模块现默认使用空间级模型,确保开箱即用的一致性行为<br>* 智能模型回退:当已配置的情绪或反思模型为空或不可用时,系统自动回退至空间默认模型,避免静默失败<br>* 记忆模型回退兜底:当记忆中配置的模型为空或不可用时,系统优雅降级至空间默认模型",
|
||||
"5. 性能与扩展 ⚡<br>* 模型并发(model_api_keys):支持并发模型API Key管理,实现并行模型调用,提升高负载场景下的吞吐能力",
|
||||
"6. 稳健性与缺陷修复 🔧<br>* 记忆配置版本固定:修复用户记忆配置未跟随应用版本发布固定的问题,消除跨部署的行为不一致<br>* 空间默认记忆保护:空间级默认记忆配置现不可删除;用户级配置仍可删除<br>* Agent与工作流配置兜底:解决Agent和工作流节点中记忆配置可能为空、或已选择但未配置的边界情况——全面的回退处理现可防止运行时错误<br>* 隐形记忆字段重命名:将隐形记忆接口JSON响应中的user_id修正为end_user_id,与规范数据模型对齐<br>* 记忆配置ID迁移:将Agent和工作流记忆配置中的memory_content重命名为memory_config_id,保持API一致性<br>* Worker-Memory告警解决:解决worker-memory服务中的告警级别问题,提升运维监控清晰度<br>* 双语接口修复:修复记忆相关API接口的中英文不一致问题<br>* 新用户记忆配置自动回填:新创建的EndUser若memory_config_id为None,系统自动从最新Release获取memory_config_id并回填<br>* 存量用户记忆配置自动回填:已有EndUser若memory_config_id为None,系统同样从最新Release获取并回填,确保向后兼容,无需手动迁移",
|
||||
"<br>",
|
||||
"Memory Bear v0.2.4 向生产级稳健性迈进,Skills框架与多模态支持开启认知平台新篇章。",
|
||||
"记忆熊,智慧致远,从容应对真实世界的多样性。🐻✨"
|
||||
]
|
||||
},
|
||||
"introduction_en": {
|
||||
"codeName": "ZhiYuan",
|
||||
"releaseDate": "2026-2-11",
|
||||
"upgradePosition": "🐻 Production-grade resilience release — Wisdom Reaching Far, gracefully handling complex scenarios",
|
||||
"coreUpgrades": [
|
||||
"1. Skills Framework 🛠️<br>* Skills Support: Introduced a new Skills system, enabling extensible capability modules that can be dynamically loaded and orchestrated within agents and workflows",
|
||||
"2. Multimodal & Interaction 💬<br>* File Multimodal Support: Full multimodal file handling across message input, LLM processing, and output rendering — supporting richer, media-aware conversations<br>* Voice Interaction: Voice-based interaction capabilities are under active development, laying the groundwork for hands-free conversational experiences (In Progress)",
|
||||
"3. Knowledge Base Integration 📚<br>* Feishu Knowledge Base: Seamless integration with Feishu (Lark) document repositories for enterprise knowledge retrieval<br>* Yuque Knowledge Base: Native connector for Yuque documentation platforms, expanding coverage of Chinese enterprise tooling<br>* Web Site Knowledge Base: General-purpose web site crawling and indexing for knowledge base construction from public web content<br>* Visual Model Selection: Knowledge base visual model configuration now supports both LLM and Chat model types, removing the previous restriction to Chat-only selection",
|
||||
"4. Memory Intelligence 🧠<br>* Ontology Engineering (Phase 2): Advanced memory scene classification and extraction powered by ontology engineering — enabling structured, domain-aware memory organization with improved categorization accuracy<br>* Default Model Configuration: Emotion analysis, reflection, and memory extraction modules now default to the space-level model, ensuring consistent behavior out of the box<br>* Intelligent Model Fallback: If configured emotion or reflection models are empty or unavailable, the system automatically falls back to the space default model — preventing silent failures<br>* Memory Config Fallback for Models: When any memory-configured model is empty or unavailable, the system gracefully degrades to the space default model",
|
||||
"5. Performance & Scalability ⚡<br>* Model Concurrency (model_api_keys): Support for concurrent model API key management, enabling parallel model invocations and improved throughput for high-load scenarios",
|
||||
"6. Robustness & Bug Fixes 🔧<br>* Memory Config Version Pinning: Fixed an issue where user memory configurations were not pinned to application release versions, causing inconsistent behavior across deployments<br>* Space Default Memory Protection: Space-level default memory configurations are now protected from deletion; user-level configurations remain deletable<br>* Agent & Workflow Config Fallback: Resolved edge cases in Agent and Workflow nodes where memory config could be empty or selected but unconfigured — comprehensive fallback handling now prevents runtime errors<br>* Implicit Memory Field Rename: Corrected user_id to end_user_id in JSON responses from implicit memory interfaces, aligning with the canonical data model<br>* Memory Config ID Migration: Renamed memory_content to memory_config_id in Agent and Workflow memory configurations for API consistency<br>* Worker-Memory Alerts: Resolved warning-level alerts in the worker-memory service, improving operational monitoring clarity<br>* Bilingual Interface Fixes: Fixed Chinese/English language inconsistencies across memory-related API interfaces<br>* EndUser Memory Config Auto-Backfill (New Users): When a newly created EndUser has memory_config_id as None, the system automatically fetches the latest release's memory_config_id and backfills it<br>* EndUser Memory Config Auto-Backfill (Existing Users): For existing EndUsers with memory_config_id as None, the system similarly retrieves and backfills from the latest release — ensuring backward compatibility without manual migration",
|
||||
"<br>",
|
||||
"Memory Bear v0.2.4 advances toward production-grade resilience, with the Skills framework and multimodal support opening a new chapter for the cognitive platform.",
|
||||
"MemoryBear — Wisdom Reaching Far, gracefully handling real-world variability. 🐻✨"
|
||||
]
|
||||
}
|
||||
},
|
||||
"v0.2.3": {
|
||||
"introduction": {
|
||||
"codeName": "归墟",
|
||||
|
||||
@@ -64,6 +64,9 @@ LANGCHAIN_ENDPOINT=
|
||||
# Generate a new one with: openssl rand -hex 32
|
||||
SECRET_KEY=your-secret-key-here-generate-with-openssl-rand-hex-32
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION=
|
||||
|
||||
# JWT Token expiration settings
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
@@ -129,6 +132,12 @@ KB_image2text_id=
|
||||
config_id=
|
||||
reranker_id=
|
||||
|
||||
# Email Configuration
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USER=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# 本体类型融合配置 (记得写入env_example)
|
||||
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||
|
||||
66
api/migrations/versions/75e28690ae87_202602251230.py
Normal file
66
api/migrations/versions/75e28690ae87_202602251230.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""202602251230
|
||||
|
||||
Revision ID: 75e28690ae87
|
||||
Revises: bab823f7cc82
|
||||
Create Date: 2026-02-25 12:27:36.919237
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '75e28690ae87'
|
||||
down_revision: Union[str, None] = 'bab823f7cc82'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('mcp_market_configs',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('mcp_market_id', sa.UUID(), nullable=False, comment='mcp_markets.id'),
|
||||
sa.Column('token', sa.String(), nullable=True, comment='mcp market token'),
|
||||
sa.Column('status', sa.Integer(), nullable=True, comment='connect status(0: Not connected, 1: connected)'),
|
||||
sa.Column('tenant_id', sa.UUID(), nullable=False, comment='tenant.id'),
|
||||
sa.Column('created_by', sa.UUID(), nullable=False, comment='users.id'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mcp_market_configs_id'), 'mcp_market_configs', ['id'], unique=False)
|
||||
op.create_table('mcp_markets',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False, comment='mcp market name'),
|
||||
sa.Column('description', sa.String(), nullable=True, comment='mcp market description'),
|
||||
sa.Column('logo_url', sa.String(), nullable=True, comment='logo url'),
|
||||
sa.Column('mcp_count', sa.Integer(), nullable=True, comment='mcp count'),
|
||||
sa.Column('url', sa.String(), nullable=False, comment='mcp market url'),
|
||||
sa.Column('category', sa.String(), nullable=False, comment='category'),
|
||||
sa.Column('created_by', sa.UUID(), nullable=False, comment='users.id'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_mcp_markets_category'), 'mcp_markets', ['category'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_description'), 'mcp_markets', ['description'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_id'), 'mcp_markets', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_logo_url'), 'mcp_markets', ['logo_url'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_name'), 'mcp_markets', ['name'], unique=False)
|
||||
op.create_index(op.f('ix_mcp_markets_url'), 'mcp_markets', ['url'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_mcp_markets_url'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_name'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_logo_url'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_id'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_description'), table_name='mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_markets_category'), table_name='mcp_markets')
|
||||
op.drop_table('mcp_markets')
|
||||
op.drop_index(op.f('ix_mcp_market_configs_id'), table_name='mcp_market_configs')
|
||||
op.drop_table('mcp_market_configs')
|
||||
# ### end Alembic commands ###
|
||||
36
api/migrations/versions/7672d8f0f939_202602271020.py
Normal file
36
api/migrations/versions/7672d8f0f939_202602271020.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""202602271020
|
||||
|
||||
Revision ID: 7672d8f0f939
|
||||
Revises: 75e28690ae87
|
||||
Create Date: 2026-02-27 10:21:46.951584
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '7672d8f0f939'
|
||||
down_revision: Union[str, None] = '75e28690ae87'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('file_metadata', 'workspace_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=True,
|
||||
existing_comment='Workspace ID')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('file_metadata', 'workspace_id',
|
||||
existing_type=sa.UUID(),
|
||||
nullable=False,
|
||||
existing_comment='Workspace ID')
|
||||
# ### end Alembic commands ###
|
||||
@@ -144,6 +144,7 @@ dependencies = [
|
||||
"rdflib>=7.0.0",
|
||||
"lxml>=4.9.0",
|
||||
"httpx>=0.28.0",
|
||||
"modelscope>=1.34.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -137,3 +137,4 @@ boto3>=1.28.0
|
||||
aiofiles>=23.0.0
|
||||
lxml>=4.9.0
|
||||
httpx>=0.28.0
|
||||
modelscope>=1.34.0
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool, VariableSelector
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool, VariableSelector
|
||||
|
||||
|
||||
# ==================== VariableSelector 测试 ====================
|
||||
|
||||
@@ -6,8 +6,8 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
TEST_WORKSPACE_ID = "test_workspace_id"
|
||||
TEST_USER_ID = "test_user_id"
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
# @Time : 2026/2/6
|
||||
import pytest
|
||||
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import StartNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
from tests.workflow.nodes.base import (
|
||||
simple_state,
|
||||
simple_state,
|
||||
simple_vairable_pool,
|
||||
TEST_EXECUTION_ID,
|
||||
TEST_WORKSPACE_ID,
|
||||
|
||||
Submodule redbear-mem-benchmark updated: 0c4bcafbc1...8494e82498
@@ -33,8 +33,6 @@ key = b64decode(key)
|
||||
|
||||
os.chdir(running_path)
|
||||
|
||||
# Preload code
|
||||
{{preload}}
|
||||
|
||||
# Apply security if library is available
|
||||
init_status = lib.init_seccomp({{uid}}, {{gid}}, {{enable_network}})
|
||||
@@ -42,6 +40,8 @@ if init_status != 0:
|
||||
raise Exception(f"code executor err - {str(init_status)}")
|
||||
del lib
|
||||
|
||||
# Preload code
|
||||
{{preload}}
|
||||
# Decrypt and execute code
|
||||
code = b64decode("{{code}}")
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ function App() {
|
||||
const { checkJump } = useUser();
|
||||
useEffect(() => {
|
||||
const authToken = cookieUtils.get('authToken')
|
||||
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump')) {
|
||||
if (!authToken && !window.location.hash.includes('#/login') && !window.location.hash.includes('#/conversation/') && !window.location.hash.includes('#/jump') && !window.location.hash.includes('#/invite-register')) {
|
||||
window.location.href = `/#/login`;
|
||||
} else {
|
||||
checkJump()
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 14:00:06
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-03 14:58:32
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type {
|
||||
@@ -163,9 +163,14 @@ export const getImplicitInterestAreas = (end_user_id: string) => {
|
||||
export const getImplicitHabits = (end_user_id: string) => {
|
||||
return request.get(`/memory/implicit-memory/habits/${end_user_id}`)
|
||||
}
|
||||
// Implicit Memory - Generate user portrait
|
||||
export const generateProfile = (end_user_id: string) => {
|
||||
return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id })
|
||||
}
|
||||
// Implicit Memory - Check if data exists
|
||||
export const implicitCheckData = (end_user_id: string) => {
|
||||
return request.get(`/memory/implicit-memory/check-data/${end_user_id}`)
|
||||
}
|
||||
// Short-term memory
|
||||
export const getShortTerm = (end_user_id: string) => {
|
||||
return request.get(`/memory/short/short_term`, { end_user_id })
|
||||
|
||||
@@ -66,9 +66,9 @@ export const addModelPlaza = (model_base_id: string) => {
|
||||
}
|
||||
// Create custom model
|
||||
export const addCustomModel = (data: CustomModelForm) => {
|
||||
return request.post('/models/model_plaza', data)
|
||||
return request.post('/models', data)
|
||||
}
|
||||
// Update custom model
|
||||
export const updateCustomModel = (model_base_id: string, data: CustomModelForm) => {
|
||||
return request.put(`/models/model_plaza/${model_base_id}`, data)
|
||||
return request.put(`/models/${model_base_id}`, data)
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
/*
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 14:00:23
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-03 14:00:23
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-25 11:17:44
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type { CreateModalData } from '@/views/UserManagement/types'
|
||||
import type { CreateModalData, ChangeEmailModalForm } from '@/views/UserManagement/types'
|
||||
import { cookieUtils } from '@/utils/request'
|
||||
|
||||
// User info
|
||||
@@ -28,6 +28,10 @@ export const refreshToken = () => {
|
||||
export const changePassword = (data: { user_id: string; new_password: string }) => {
|
||||
return request.put('/users/admin/change-password', data)
|
||||
}
|
||||
// Verify password
|
||||
export const verifyPassword = (data: { password: string }) => {
|
||||
return request.post('/users/verify_pwd', data)
|
||||
}
|
||||
// Disable user
|
||||
export const deleteUser = (user_id: string) => {
|
||||
return request.delete(`/users/${user_id}`)
|
||||
@@ -44,4 +48,12 @@ export const addUser = (data: CreateModalData) => {
|
||||
export const logoutUrl = '/logout'
|
||||
export const logout = () => {
|
||||
return request.post(logoutUrl)
|
||||
}
|
||||
// Send email verification code
|
||||
export const sendEmailCode = (data: { email: string }) => {
|
||||
return request.post('/users/send-email-code', data)
|
||||
}
|
||||
// Verify code and change email
|
||||
export const changeEmail = (data: ChangeEmailModalForm) => {
|
||||
return request.put('/users/change-email', data)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user