Merge branch 'develop' into feature/tool_yjp

This commit is contained in:
yujiangping
2026-03-06 15:03:57 +08:00
120 changed files with 3672 additions and 2480 deletions

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (Using Redis as broker)
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
# JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (使用Redis作为broker)
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
# JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -3,10 +3,8 @@ Cache 缓存模块
提供各种缓存功能的统一入口
"""
from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache
from .memory import InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

View File

@@ -3,12 +3,8 @@ Memory 缓存模块
提供记忆系统相关的缓存功能
"""
from .emotion_memory import EmotionMemoryCache
from .implicit_memory import ImplicitMemoryCache
from .interest_memory import InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

View File

@@ -1,134 +0,0 @@
"""
Emotion Suggestions Cache
情绪个性化建议缓存模块
用于缓存用户的情绪个性化建议数据
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
class EmotionMemoryCache:
"""情绪建议缓存类"""
# Key 前缀
PREFIX = "cache:memory:emotion_memory"
@classmethod
def _get_key(cls, *parts: str) -> str:
"""生成 Redis key
Args:
*parts: key 的各个部分
Returns:
完整的 Redis key
"""
return ":".join([cls.PREFIX] + list(parts))
@classmethod
async def set_emotion_suggestions(
cls,
user_id: str,
suggestions_data: Dict[str, Any],
expire: int = 86400
) -> bool:
"""设置用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
suggestions_data: 建议数据字典,包含:
- health_summary: 健康状态摘要
- suggestions: 建议列表
- generated_at: 生成时间(可选)
expire: 过期时间默认24小时86400秒
Returns:
是否设置成功
"""
try:
key = cls._get_key("suggestions", user_id)
# 添加生成时间戳
if "generated_at" not in suggestions_data:
suggestions_data["generated_at"] = datetime.now().isoformat()
# 添加缓存标记
suggestions_data["cached"] = True
value = json.dumps(suggestions_data, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
"""获取用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
Returns:
建议数据字典,如果不存在或已过期返回 None
"""
try:
key = cls._get_key("suggestions", user_id)
value = await aio_redis.get(key)
if value:
data = json.loads(value)
logger.info(f"成功获取情绪建议缓存: {key}")
return data
logger.info(f"情绪建议缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
"""删除用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
Returns:
是否删除成功
"""
try:
key = cls._get_key("suggestions", user_id)
result = await aio_redis.delete(key)
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_suggestions_ttl(cls, user_id: str) -> int:
"""获取情绪建议缓存的剩余过期时间
Args:
user_id: 用户IDend_user_id
Returns:
剩余秒数,-1表示永不过期-2表示key不存在
"""
try:
key = cls._get_key("suggestions", user_id)
ttl = await aio_redis.ttl(key)
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}")
return ttl
except Exception as e:
logger.error(f"获取情绪建议缓存TTL失败: {e}")
return -2

View File

@@ -1,136 +0,0 @@
"""
Implicit Memory Profile Cache
隐式记忆用户画像缓存模块
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
class ImplicitMemoryCache:
"""隐式记忆用户画像缓存类"""
# Key 前缀
PREFIX = "cache:memory:implicit_memory"
@classmethod
def _get_key(cls, *parts: str) -> str:
"""生成 Redis key
Args:
*parts: key 的各个部分
Returns:
完整的 Redis key
"""
return ":".join([cls.PREFIX] + list(parts))
@classmethod
async def set_user_profile(
cls,
user_id: str,
profile_data: Dict[str, Any],
expire: int = 86400
) -> bool:
"""设置用户完整画像缓存
Args:
user_id: 用户IDend_user_id
profile_data: 画像数据字典,包含:
- preferences: 偏好标签列表
- portrait: 四维画像对象
- interest_areas: 兴趣领域分布对象
- habits: 行为习惯列表
- generated_at: 生成时间(可选)
expire: 过期时间默认24小时86400秒
Returns:
是否设置成功
"""
try:
key = cls._get_key("profile", user_id)
# 添加生成时间戳
if "generated_at" not in profile_data:
profile_data["generated_at"] = datetime.now().isoformat()
# 添加缓存标记
profile_data["cached"] = True
value = json.dumps(profile_data, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
"""获取用户完整画像缓存
Args:
user_id: 用户IDend_user_id
Returns:
画像数据字典,如果不存在或已过期返回 None
"""
try:
key = cls._get_key("profile", user_id)
value = await aio_redis.get(key)
if value:
data = json.loads(value)
logger.info(f"成功获取用户画像缓存: {key}")
return data
logger.info(f"用户画像缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_user_profile(cls, user_id: str) -> bool:
"""删除用户完整画像缓存
Args:
user_id: 用户IDend_user_id
Returns:
是否删除成功
"""
try:
key = cls._get_key("profile", user_id)
result = await aio_redis.delete(key)
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_profile_ttl(cls, user_id: str) -> int:
"""获取用户画像缓存的剩余过期时间
Args:
user_id: 用户IDend_user_id
Returns:
剩余秒数,-1表示永不过期-2表示key不存在
"""
try:
key = cls._get_key("profile", user_id)
ttl = await aio_redis.ttl(key)
logger.debug(f"用户画像缓存TTL: {key} = {ttl}")
return ttl
except Exception as e:
logger.error(f"获取用户画像缓存TTL失败: {e}")
return -2

View File

@@ -1,26 +1,54 @@
import os
import platform
from datetime import timedelta
from celery.schedules import crontab
from urllib.parse import quote
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
celery_app = Celery(
"redbear_tasks",
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
broker=_broker_url,
backend=_backend_url,
)
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
@@ -44,8 +72,8 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
@@ -84,6 +112,7 @@ celery_app.conf.update(
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
},
)
@@ -95,6 +124,10 @@ memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
implicit_emotions_update_schedule = crontab(
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
#构建定时任务配置
beat_schedule_config = {
@@ -120,6 +153,11 @@ beat_schedule_config = {
"schedule": memory_increment_schedule,
"args": (),
},
"update-implicit-emotions-storage": {
"task": "app.tasks.update_implicit_emotions_storage",
"schedule": implicit_emotions_update_schedule,
"args": (),
},
}
celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -396,10 +396,10 @@ async def draft_run(
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
service = AppService(db)
draft_service = DraftRunService(db)
draft_service = AgentRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
@@ -484,8 +484,8 @@ async def draft_run(
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
@@ -789,8 +789,8 @@ async def draft_run_compare(
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
@@ -820,8 +820,8 @@ async def draft_run_compare(
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
@@ -835,7 +835,8 @@ async def draft_run_compare(
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
timeout=payload.timeout or 60,
files=payload.files
)
logger.info(

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -208,14 +208,64 @@ async def get_emotion_health(
# @router.post("/check-data", response_model=ApiResponse)
# async def check_emotion_data_exists(
# request: EmotionSuggestionsRequest,
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user),
# ):
# """检查用户情绪建议数据是否存在
# Args:
# request: 包含 end_user_id
# db: 数据库会话
# current_user: 当前用户
# Returns:
# 数据存在状态
# """
# try:
# api_logger.info(
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
# extra={"end_user_id": request.end_user_id}
# )
# # 从数据库获取建议
# data = await emotion_service.get_cached_suggestions(
# end_user_id=request.end_user_id,
# db=db
# )
# if data is None:
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
# return fail(
# BizCode.NOT_FOUND,
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
# {"exists": False}
# )
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
# return success(data={"exists": True}, msg="情绪建议数据已存在")
# except Exception as e:
# api_logger.error(
# f"检查情绪建议数据失败: {str(e)}",
# extra={"end_user_id": request.end_user_id},
# exc_info=True
# )
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail=f"检查情绪建议数据失败: {str(e)}"
# )
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
request: EmotionSuggestionsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取个性化情绪建议(从缓存读取)
"""获取个性化情绪建议(从数据库读取)
Args:
request: 包含 end_user_id 和可选的 config_id
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
current_user: 当前用户
Returns:
存的个性化情绪建议响应
的个性化情绪建议响应
"""
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
f"用户 {current_user.username} 请求获取个性化情绪建议",
extra={
"end_user_id": request.end_user_id,
"config_id": request.config_id
}
)
# 从缓存获取建议
# 从数据库获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.end_user_id,
db=db
)
if data is None:
# 缓存不存在或已过期,自动触发生成
api_logger.info(
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
f"用户 {request.end_user_id} 的建议数据不存在",
extra={"end_user_id": request.end_user_id}
)
try:
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.end_user_id,
db=db,
language=language
)
# 保存到缓存
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
db=db,
expires_hours=24
)
except (ValueError, KeyError) as gen_e:
# 预期内的业务异常:配置缺失、数据格式问题等
api_logger.warning(
f"自动生成建议失败(业务异常): {str(gen_e)}",
extra={"end_user_id": request.end_user_id}
)
return fail(
BizCode.NOT_FOUND,
f"自动生成建议失败: {str(gen_e)}",
""
)
except Exception as gen_e:
# 非预期异常:记录完整 traceback 便于排查
api_logger.error(
f"自动生成建议时发生未预期异常: {str(gen_e)}",
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成建议时发生内部错误: {str(gen_e)}"
)
return success(
data={"exists": False},
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(
"个性化建议获取成功(缓存)",
"个性化建议获取成功",
extra={
"end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议获取成功(缓存)")
return success(data=data, msg="个性化建议获取成功")
except Exception as e:
api_logger.error(
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""生成个性化情绪建议调用LLM并缓存
"""生成个性化情绪建议调用LLM并保存到数据库
Args:
request: 包含 end_user_id
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
language=language
)
# 保存到缓存
# 保存到数据库
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
db=db,
expires_hours=24
db=db
)
api_logger.info(
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成个性化建议失败: {str(e)}"
)
)

View File

@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def check_user_data_exists(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
检查用户画像数据是否存在
Args:
end_user_id: 目标用户ID
Returns:
数据存在状态
"""
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
try:
# Validate inputs
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return success(
data={"exists": False},
msg="画像数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
return success(data={"exists": True}, msg="画像数据已存在")
except Exception as e:
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
@@ -159,12 +201,8 @@ async def get_preference_tags(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract preferences from cache
preferences = cached_profile.get("preferences", [])
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
@@ -330,12 +360,8 @@ async def get_behavior_habits(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract habits from cache
habits = cached_profile.get("habits", [])

View File

@@ -90,7 +90,7 @@ async def get_mcp_servers(
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
api_logger.error(f"Failed to get MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get MCP servers: {str(e)}"
@@ -118,6 +118,65 @@ async def get_mcp_servers(
return success(data=result, msg="Query of mcp servers list successful")
@router.get("/operational_mcp_servers", response_model=ApiResponse)
async def get_operational_mcp_servers(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the operational mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + operational mcp server list
"""
api_logger.info(
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
api.login(token)
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
r = api.session.get(url, headers=headers, cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get operational MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 3. Return structured response
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
@router.get("/mcp_server", response_model=ApiResponse)
async def get_mcp_server(
mcp_market_config_id: uuid.UUID,

View File

@@ -1,28 +1,29 @@
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import cur_workspace_access_guard, get_current_user
from app.models import ModelApiKey
from app.models.user_model import User
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.redis_tool import store
from app.repositories import knowledge_repository, WorkspaceRepository
from app.repositories import knowledge_repository
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv()
api_logger = get_api_logger()
@@ -37,7 +38,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
@@ -55,8 +56,9 @@ async def get_health_status(
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
@@ -75,16 +77,16 @@ async def download_log(
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
@@ -119,10 +121,10 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
@@ -136,11 +138,11 @@ async def write_server(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -149,7 +151,7 @@ async def write_server(
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
@@ -161,13 +163,15 @@ async def write_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
@@ -175,7 +179,7 @@ async def write_server(
messages_list,
config_id,
db,
storage_type,
storage_type,
user_rag_memory_id,
language
)
@@ -195,10 +199,10 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
@@ -213,10 +217,11 @@ async def write_server_async(
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
@@ -244,7 +249,7 @@ async def write_server_async(
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
@@ -254,9 +259,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
@@ -291,8 +296,9 @@ async def read_server(
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
@@ -319,7 +326,7 @@ async def read_server(
db=db
)
if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info
result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -335,9 +342,10 @@ async def read_server(
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id:str = Form(..., description="模型ID"),
model_id: str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
@@ -361,7 +366,7 @@ async def file_update(
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
@@ -375,12 +380,12 @@ async def file_update(
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -430,8 +435,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
@@ -452,7 +457,7 @@ async def get_read_task_result(
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -470,7 +475,7 @@ async def get_read_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -479,7 +484,7 @@ async def get_read_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -499,7 +504,7 @@ async def get_read_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -507,8 +512,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
@@ -529,7 +534,7 @@ async def get_write_task_result(
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -547,7 +552,7 @@ async def get_write_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -556,7 +561,7 @@ async def get_write_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -576,7 +581,7 @@ async def get_write_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -584,9 +589,9 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
@@ -629,9 +634,10 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
- 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
@@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api(
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
获取指定用户的兴趣分布标签
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
@@ -756,17 +757,17 @@ async def get_user_profile_api(
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
@@ -782,9 +783,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取终端用户关联的记忆配置
@@ -803,9 +804,9 @@ async def get_end_user_connected_config(
from app.services.memory_agent_service import (
get_end_user_connected_config as get_config,
)
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try:
result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功")
@@ -814,4 +815,4 @@ async def get_end_user_connected_config(
return fail(BizCode.NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -606,8 +606,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量

View File

@@ -2,7 +2,7 @@ from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
@@ -85,6 +85,7 @@ def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -99,7 +100,29 @@ def create_config(
svc = DataConfigService(db)
result = svc.create(payload)
return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))

View File

@@ -469,7 +469,9 @@ async def create_model_api_key_by_provider(
config=api_key_data.config,
is_active=api_key_data.is_active,
priority=api_key_data.priority,
model_config_ids=model_config_ids
model_config_ids=model_config_ids,
capability=api_key_data.capability,
is_omni=api_key_data.is_omni
)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)

View File

@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
from urllib.parse import quote
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
from fastapi.responses import StreamingResponse
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.config import settings
@@ -124,15 +124,23 @@ def _get_ontology_service(
)
# 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
from app.repositories.model_repository import ModelApiKeyRepository
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
# from app.repositories.model_repository import ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
if not api_key_config:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有可用的API密钥"
)
api_key_config = api_keys[0]
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
# if not api_keys:
# logger.error(f"Model {llm_id} has no active API key")
# raise HTTPException(
# status_code=400,
# detail="指定的LLM模型没有可用的API密钥"
# )
# api_key_config = api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
@@ -154,6 +162,7 @@ def _get_ontology_service(
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
max_retries=3,
timeout=60.0
)
@@ -280,7 +289,8 @@ async def extract_ontology(
async def create_scene(
request: SceneCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体场景
@@ -351,8 +361,18 @@ async def create_scene(
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
from app.core.language_utils import get_language_from_header
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
@@ -514,10 +534,9 @@ async def delete_scene(
f"尝试删除系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可删除",
"该场景为系统预设场景,不允许删除"
raise HTTPException(
status_code=400,
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
)
# 创建OntologyService实例
@@ -543,6 +562,9 @@ async def delete_scene(
return success(data={"deleted": success_flag}, msg="场景删除成功")
except HTTPException:
raise
except ValueError as e:
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
@@ -650,7 +672,8 @@ async def get_scenes(
async def create_class(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体类型
@@ -665,7 +688,7 @@ async def create_class(
ApiResponse: 包含创建的类型信息
"""
from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user)
return await create_class_handler(request, db, current_user, x_language_type)
@router.put("/class/{class_id}", response_model=ApiResponse)

View File

@@ -7,7 +7,7 @@
from uuid import UUID
from typing import Optional
from fastapi import Depends
from fastapi import Depends, Header
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
@@ -58,7 +58,7 @@ async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -71,14 +71,14 @@ async def scenes_handler(
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
page_size: 每页数量(可选,仅在全量查询时有效)
pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
)
try:
@@ -105,13 +105,13 @@ async def scenes_handler(
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
@@ -119,17 +119,15 @@ async def scenes_handler(
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
if page is not None and pagesize is not None:
start_idx = (page - 1) * pagesize
end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
@@ -141,17 +139,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
@@ -165,28 +162,25 @@ async def scenes_handler(
)
else:
# 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或page_size中的一个返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, page_size)
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
@@ -198,17 +192,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
@@ -238,7 +231,8 @@ async def scenes_handler(
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
@@ -271,8 +265,11 @@ async def create_class_handler(
]
if count == 1:
# 单个创建
# 单个创建 - 先检查重名
class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
@@ -330,12 +327,36 @@ async def create_class_handler(
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
err_str = str(e)
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
@@ -615,6 +636,7 @@ async def classes_handler(
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items
)

View File

@@ -249,6 +249,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -11,35 +11,37 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
):
"""初始化 LangChain Agent
@@ -60,12 +62,13 @@ class LangChainAgent:
self.provider = provider
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None
# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
@@ -73,9 +76,9 @@ class LangChainAgent:
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
self.system_prompt = system_prompt or "你是一个专业的AI助手"
logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
@@ -89,6 +92,7 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -143,21 +147,22 @@ class LangChainAgent:
"""
from langchain_core.tools import StructuredTool
from functools import wraps
wrapped_tools = []
for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None
if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue
# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""
@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
@@ -168,13 +173,13 @@ class LangChainAgent:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name
current_count = self.tool_call_counter[tool_name]
logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)
# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
@@ -185,12 +190,12 @@ class LangChainAgent:
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)
# 调用原始工具函数
return original_func(*args, **kwargs)
return wrapped_func
# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
@@ -198,17 +203,17 @@ class LangChainAgent:
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> List[BaseMessage]:
"""准备消息列表
@@ -248,7 +253,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -261,23 +266,26 @@ class LangChainAgent:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
if self.provider.lower() in ["bedrock", "anthropic"]:
# Anthropic/Bedrock: {"type": "text", "text": "..."}
content_parts = [{"type": "text", "text": text}]
else:
# 通义千问等: {"text": "..."}
content_parts = [{"text": text}]
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]
# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)
logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)
return content_parts
async def chat(
@@ -302,7 +310,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat= message
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -322,8 +330,8 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -367,14 +375,14 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")
# 处理多模态响应content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
@@ -407,12 +415,13 @@ class LangChainAgent:
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -439,16 +448,16 @@ class LangChainAgent:
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
"""执行流式对话
@@ -482,7 +491,6 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
@@ -500,13 +508,13 @@ class LangChainAgent:
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
@@ -540,7 +548,7 @@ class LangChainAgent:
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
@@ -577,13 +585,13 @@ class LangChainAgent:
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
@@ -595,7 +603,8 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -609,5 +618,3 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -190,8 +190,10 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
# SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
@@ -219,8 +221,12 @@ class Settings:
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行分钟0-59
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")

View File

@@ -1,10 +1,10 @@
import os
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:
try:
# 使用优化的LLM服务
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")

View File

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__)
db = next(get_db())
async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception :
retrieval_knowledge=[]
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState:
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:
# 收集所有内容
content_list = []
# 处理重排序结果
reranked = results.get('reranked_results', {})
if reranked:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
'''
problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '')
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
for key,values in problem_extension.items():
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
async def process_question_nodes(idx, question):
try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
send_verify = []
for i, j in zip(keys, val, strict=False):
if j!=['']:
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases}
return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
import time
start=time.time()
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
)
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -1,5 +1,3 @@
import os
import time
@@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
summary_service = SummaryNodeService()
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
@@ -51,10 +51,12 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state,question):
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -62,25 +64,28 @@ async def rag_knowledge(state,question):
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception :
retrieval_knowledge=[]
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数包含更好的错误处理和数据验证
"""
data = state.get("data", '')
# 构建系统提示词
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
@@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
try:
# 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
if structured is None:
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
@@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
return aimessages
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
try:
logger.info("尝试非结构化输出作为fallback")
@@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)
if response and response.strip():
# 简单清理响应
cleaned_response = response.strip()
@@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response
else:
return "信息不足,无法回答"
except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve={
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title":"快速检索",
"title": "快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
}
}
return input_summary,retrieve
return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState:
start=time.time()
storage_type=state.get("storage_type",'')
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
end_user_id=state.get("end_user_id", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
history = await summary_history(state)
search_params = {
"end_user_id": end_user_id,
"question": data,
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
}
try:
if storage_type!="rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True )
summary= {
logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary = {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary":summary}
return {"summary": summary}
async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '')
history = await summary_history( state)
async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
with open("检索.json","w",encoding='utf-8') as f:
with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", [])
start=time.time()
retrieve_info_str=[]
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data=='':
retrieve_info_str=''
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key=='Answer_Small':
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages=await summary_llm(state,history,retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
return {"summary": summary}
async def Summary(state: ReadState)-> ReadState:
start=time.time()
async def Summary(state: ReadState) -> ReadState:
start = time.time()
query = state.get("data", '')
verify=state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '')
retrieve_info_str=''
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key=='answer_small':
if key == 'answer_small':
for i in value:
retrieve_info_str+=i+'\n'
history=await summary_history(state)
retrieve_info_str += i + '\n'
history = await summary_history(state)
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages=await summary_llm(state,history,data,
'summary_prompt.jinja2','summary',SummaryResponse,0)
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '')
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= {
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary":result}
return {"summary": result}

View File

@@ -1,8 +1,9 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# 将 VerificationItem 对象转换为字典列表
verified_data = []
if messages_deal.expansion_issue:
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)
Verify_result = {
"status": messages_deal.split_result,
"verified_data": verified_data,
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
}
}
return Verify_result
async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
split_result="failed",
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}
}

View File

@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
)
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
@@ -49,7 +47,7 @@ async def make_read_graph():
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
# 添加边
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
@@ -62,20 +60,20 @@ async def make_read_graph():
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# 编译工作流
graph = workflow.compile()
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
@@ -92,17 +90,19 @@ async def main():
service_name="MemoryAgentService"
)
import time
start=time.time()
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
@@ -110,7 +110,7 @@ async def main():
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
@@ -125,23 +125,22 @@ async def main():
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
@@ -161,17 +160,20 @@ async def main():
#
print(f"=== 最终摘要 ===")
print(summary)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
end=time.time()
print(100*'y')
print(f"总耗时: {end-start}s")
print(100*'y')
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
if __name__ == "__main__":
import asyncio
asyncio.run(main())

View File

@@ -1,56 +0,0 @@
import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
class LLMClientPool:
"""LLM客户端连接池"""
def __init__(self, max_size: int = 5):
self.max_size = max_size
self.pools: Dict[str, asyncio.Queue] = {}
self.active_clients: Dict[str, int] = {}
async def get_client(self, llm_model_id: str):
"""获取LLM客户端"""
if llm_model_id not in self.pools:
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
self.active_clients[llm_model_id] = 0
pool = self.pools[llm_model_id]
try:
# 尝试从池中获取客户端
client = pool.get_nowait()
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
return client
except asyncio.QueueEmpty:
# 池为空,创建新客户端
if self.active_clients[llm_model_id] < self.max_size:
db_session = next(get_db())
client = get_llm_client_fast(llm_model_id, db_session)
self.active_clients[llm_model_id] += 1
logger.debug(f"创建新LLM客户端: {llm_model_id}")
return client
else:
# 等待可用客户端
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
return await pool.get()
async def return_client(self, llm_model_id: str, client):
"""归还LLM客户端到池中"""
if llm_model_id in self.pools:
try:
self.pools[llm_model_id].put_nowait(client)
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
except asyncio.QueueFull:
# 池已满,丢弃客户端
self.active_clients[llm_model_id] -= 1
logger.debug(f"池已满丢弃LLM客户端: {llm_model_id}")
# 全局客户端池
llm_client_pool = LLMClientPool()

View File

@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field
T = TypeVar("T")
class RedBearModelConfig(BaseModel):
"""模型配置基类"""
model_name: str
provider: str
api_key: str
base_url: Optional[str] = None
is_omni: bool = False # 是否为 Omni 模型
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用可通过环境变量 LLM_TIMEOUT 配置
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
# 最大重试次数 - 默认2次以避免过长等待可通过环境变量 LLM_MAX_RETRIES 配置
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
concurrency: int = 5 # 并发限流
concurrency: int = 5 # 并发限流
extra_params: Dict[str, Any] = {}
class RedBearModelFactory:
"""模型工厂类"""
@classmethod
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
# 打印供应商信息用于调试
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
import httpx
if not config.base_url:
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
timeout_config = httpx.Timeout(
timeout=config.timeout,
connect=60.0,
read=config.timeout,
write=60.0,
pool=10.0,
)
return {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -65,7 +89,7 @@ class RedBearModelFactory:
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
}
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
@@ -82,7 +106,7 @@ class RedBearModelFactory:
# region 从 base_url 或 extra_params 获取
from botocore.config import Config as BotoConfig
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
# Configure with increased connection pool
@@ -90,16 +114,16 @@ class RedBearModelFactory:
max_pool_connections=max_pool_connections,
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
)
# 标准化模型 ID自动转换简化名称为完整 Bedrock Model ID
model_id = normalize_bedrock_model_id(config.model_name)
params = {
"model_id": model_id,
"config": boto_config,
**config.extra_params
}
# 解析 API key (格式: access_key_id:secret_access_key)
if config.api_key and ":" in config.api_key:
access_key_id, secret_access_key = config.api_key.split(":", 1)
@@ -107,45 +131,52 @@ class RedBearModelFactory:
params["aws_secret_access_key"] = secret_access_key
elif config.api_key:
params["aws_access_key_id"] = config.api_key
# 设置 region
if config.base_url:
params["region_name"] = config.base_url
elif "region_name" not in params:
params["region_name"] = "us-east-1" # 默认区域
return params
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@classmethod
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
return {
return {
"model": config.model_name,
# "base_url": config.base_url,
"jina_api_key": config.api_key,
**config.extra_params
}
}
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
"""根据模型提供商获取对应的模型类"""
provider = config.provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI
return OpenAI
elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings
return OpenAIEmbeddings
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.embeddings import DashScopeEmbeddings
return DashScopeEmbeddings
return DashScopeEmbeddings
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_rerank_class(provider: str):
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
# from langchain_ollama import OllamaEmbeddings
# return OllamaEmbeddings
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -6,6 +6,8 @@ models:
description: AI21 Labs大语言模型completion生成模式256000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: bedrock
@@ -15,6 +17,9 @@ models:
description: Amazon Nova大语言模型支持智能体思考、工具调用、流式工具调用、视觉能力300000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -28,6 +33,9 @@ models:
description: Anthropic Claude大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -42,6 +50,8 @@ models:
description: Cohere大语言模型支持智能体思考、工具调用、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -54,6 +64,9 @@ models:
description: DeepSeek大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -67,6 +80,8 @@ models:
description: Meta Llama大语言模型支持智能体思考、工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -78,6 +93,8 @@ models:
description: Mistral AI大语言模型支持智能体思考、工具调用32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -89,6 +106,8 @@ models:
description: OpenAI大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -101,6 +120,8 @@ models:
description: Qwen大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -113,6 +134,8 @@ models:
description: amazon.rerank-v1:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -122,6 +145,8 @@ models:
description: cohere.rerank-v3-5:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -131,6 +156,9 @@ models:
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型支持视觉能力8192上下文窗口
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 文本嵌入模型
- vision
@@ -141,6 +169,8 @@ models:
description: amazon.titan-embed-text-v1文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -150,6 +180,8 @@ models:
description: amazon.titan-embed-text-v2:0文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -159,6 +191,8 @@ models:
description: Cohere Embed 3 English文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -168,6 +202,8 @@ models:
description: Cohere Embed 3 Multilingual文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
logo: bedrock

View File

@@ -6,6 +6,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -16,6 +18,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -26,6 +30,8 @@ models:
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -36,6 +42,8 @@ models:
description: DeepSeek-V3.1大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -46,6 +54,8 @@ models:
description: DeepSeek-V3.2-exp实验版大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -56,6 +66,8 @@ models:
description: DeepSeek-V3.2大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -66,6 +78,8 @@ models:
description: DeepSeek-V3大语言模型支持智能体思考64000上下文窗口对话模式支持文本与JSON格式输出
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -76,6 +90,8 @@ models:
description: farui-plus大语言模型支持多工具调用、智能体思考、流式工具调用12288上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +104,8 @@ models:
description: GLM-4.7大语言模型支持多工具调用、智能体思考、流式工具调用202752超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +118,9 @@ models:
description: qvq-max-latest大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -112,6 +133,9 @@ models:
description: qvq-max大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -124,6 +148,8 @@ models:
description: qwen-coder-turbo-0919代码专用大语言模型支持智能体思考131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -135,6 +161,8 @@ models:
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -147,6 +175,8 @@ models:
description: qwen-max-longcontext长上下文大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -159,6 +189,8 @@ models:
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -171,6 +203,8 @@ models:
description: qwen-mt-plus多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -182,6 +216,8 @@ models:
description: qwen-mt-turbo轻量化多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -193,6 +229,8 @@ models:
description: qwen-plus-0112大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -205,6 +243,8 @@ models:
description: qwen-plus-0125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -217,6 +257,8 @@ models:
description: qwen-plus-0723大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -229,6 +271,8 @@ models:
description: qwen-plus-0806大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -241,6 +285,8 @@ models:
description: qwen-plus-0919大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -253,6 +299,8 @@ models:
description: qwen-plus-1125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -265,6 +313,8 @@ models:
description: qwen-plus-1127大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -277,6 +327,8 @@ models:
description: qwen-plus-1220大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -289,6 +341,10 @@ models:
description: qwen-vl-max多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -302,6 +358,10 @@ models:
description: qwen-vl-plus-0809多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -315,6 +375,10 @@ models:
description: qwen-vl-plus-2025-01-02多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -328,6 +392,10 @@ models:
description: qwen-vl-plus-2025-01-25多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -341,6 +409,10 @@ models:
description: qwen-vl-plus-latest多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -354,6 +426,10 @@ models:
description: qwen-vl-plus多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -367,6 +443,8 @@ models:
description: qwen2.5-0.5b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -379,6 +457,8 @@ models:
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -391,6 +471,8 @@ models:
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -403,6 +485,8 @@ models:
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -415,6 +499,8 @@ models:
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -427,6 +513,8 @@ models:
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -439,6 +527,8 @@ models:
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -451,6 +541,8 @@ models:
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -463,6 +555,8 @@ models:
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -475,6 +569,8 @@ models:
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -487,6 +583,8 @@ models:
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -498,6 +596,8 @@ models:
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -509,6 +609,8 @@ models:
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -520,6 +622,8 @@ models:
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -531,6 +635,8 @@ models:
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -544,6 +650,8 @@ models:
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -557,6 +665,8 @@ models:
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -569,6 +679,8 @@ models:
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -582,6 +694,8 @@ models:
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -594,6 +708,8 @@ models:
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -606,6 +722,11 @@ models:
description: qwen3-omni-flash-2025-12-01多模态大语言模型支持视觉、智能体思考、视频、音频能力65536上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
- audio
is_omni: true
tags:
- 大语言模型
- 多模态模型
@@ -620,6 +741,10 @@ models:
description: qwen3-vl-235b-a22b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -635,6 +760,10 @@ models:
description: qwen3-vl-235b-a22b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -650,6 +779,10 @@ models:
description: qwen3-vl-30b-a3b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -665,6 +798,10 @@ models:
description: qwen3-vl-30b-a3b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -680,6 +817,10 @@ models:
description: qwen3-vl-flash多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -695,6 +836,10 @@ models:
description: qwen3-vl-plus-2025-09-23多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -708,6 +853,10 @@ models:
description: qwen3-vl-plus多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -721,6 +870,8 @@ models:
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -732,6 +883,8 @@ models:
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -743,6 +896,8 @@ models:
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -754,6 +909,8 @@ models:
description: gte-rerank-v2重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -763,6 +920,8 @@ models:
description: gte-rerank重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -772,6 +931,9 @@ models:
description: multimodal-embedding-v1多模态嵌入模型支持视觉能力8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 嵌入模型
- 多模态模型
@@ -783,6 +945,8 @@ models:
description: text-embedding-v1文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -793,6 +957,8 @@ models:
description: text-embedding-v2文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -803,6 +969,8 @@ models:
description: text-embedding-v3文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -813,7 +981,9 @@ models:
description: text-embedding-v4文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
logo: dashscope

View File

@@ -6,7 +6,7 @@ from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
from app.models.models_model import ModelBase, ModelProvider, ModelConfig
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
for model_data in models:
config_sync_fields = {
"logo": None,
"capability": None,
"is_omni": None,
"name": None,
"provider": None,
"type": None,
"description": None
}
try:
# 检查模型是否已存在
existing = db.query(ModelBase).filter(
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
# 更新现有模型配置
for key, value in model_data.items():
setattr(existing, key, value)
# 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey
sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
if sync_fields:
# 批量更新 ModelConfig
update_kwargs = {k: model_data[k] for k in sync_fields}
db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
update_kwargs,
synchronize_session=False
)
# 更新 ModelApiKey 的 capability 和 is_omni
if 'capability' in model_data or 'is_omni' in model_data:
from app.models.models_model import ModelApiKey, model_config_api_key_association
api_key_update = {}
if 'capability' in model_data:
api_key_update['capability'] = model_data['capability']
if 'is_omni' in model_data:
api_key_update['is_omni'] = model_data['is_omni']
if api_key_update:
# 查找所有关联的 API Key
api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
ModelConfig,
ModelConfig.id == model_config_api_key_association.c.model_config_id
).filter(ModelConfig.model_id == existing.id).distinct().all()
if api_key_ids:
api_key_ids = [aid[0] for aid in api_key_ids]
db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
api_key_update,
synchronize_session=False
)
db.commit()
if not silent:
print(f"更新成功: {model_data['name']}")

View File

@@ -6,12 +6,19 @@ models:
description: chatgpt-4o-latest大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- audio
- video
is_omni: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- audio
- video
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
@@ -19,6 +26,8 @@ models:
description: gpt-3.5-turbo-0125大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -31,6 +40,8 @@ models:
description: gpt-3.5-turbo-1106大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -43,6 +54,8 @@ models:
description: gpt-3.5-turbo-16k大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -55,6 +68,8 @@ models:
description: gpt-3.5-turbo-instruct大语言模型4096上下文窗口文本补全模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: openai
@@ -64,6 +79,8 @@ models:
description: gpt-3.5-turbo大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -76,6 +93,8 @@ models:
description: gpt-4-0125-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +107,8 @@ models:
description: gpt-4-1106-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +121,9 @@ models:
description: gpt-4-turbo-2024-04-09大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -113,6 +137,8 @@ models:
description: gpt-4-turbo-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -125,6 +151,9 @@ models:
description: gpt-4-turbo大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -138,6 +167,8 @@ models:
description: o1-preview大语言模型支持智能体思考128000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -148,6 +179,9 @@ models:
description: o1大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -162,6 +196,9 @@ models:
description: o3-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -176,6 +213,8 @@ models:
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -189,6 +228,8 @@ models:
description: o3-mini大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -202,6 +243,9 @@ models:
description: o3-pro-2025-06-10大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -215,6 +259,9 @@ models:
description: o3-pro大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -228,6 +275,9 @@ models:
description: o3大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -242,6 +292,9 @@ models:
description: o4-mini-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -256,6 +309,9 @@ models:
description: o4-mini大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -270,6 +326,8 @@ models:
description: text-embedding-3-large文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -279,6 +337,8 @@ models:
description: text-embedding-3-small文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -288,6 +348,8 @@ models:
description: text-embedding-ada-002文本向量模型8097上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
logo: openai

View File

@@ -98,7 +98,7 @@ class DifyConverter(BaseConverter):
if not var_selector:
return ""
selector = var_selector.split('.')
if len(selector) not in [2, 3]:
if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
@@ -332,7 +332,9 @@ class DifyConverter(BaseConverter):
messages.append(
MessageConfig(
role="user",
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
content=self.trans_variable_format(
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
)
)
)
vision = node_data["vision"]["enabled"]

View File

@@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields):
return False

View File

@@ -303,38 +303,52 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]:
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
sys_namespace = self.variables.get("sys", {})
if literal:
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]:
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
conv_namespace = self.variables.get("conv", {})
if literal:
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]:
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
if literal:
runtime_vars = {
namespace: {
k: v.instance.to_literal()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
else:
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
return runtime_vars
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:

View File

@@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db
from app.db import get_db_context
from app.models import AppRelease
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
"""准备 Agent公共逻辑
Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
if not agent_id:
raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置")
db = next(get_db())
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
with get_db_context() as db:
release = db.query(AppRelease).filter(
AppRelease.id == agent_id
).first()
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
return draft_service, release, message
return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""非流式执行
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
Returns:
状态更新字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)")
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
)
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent非流式
result = await draft_service.run(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
)
response = result.get("response", "")
@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
Yields:
流式事件字典
"""
draft_service, release, message = self._prepare_agent(variable_pool)
release, message = self._prepare_agent(variable_pool)
logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)")
# 累积完整响应
full_response = ""
with get_db_context() as db:
draft_service = AgentRunService(db)
# 执行 Agent流式
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
async for chunk in draft_service.run_stream(
agent_config=release.config,
model_config=None,
message=message,
workspace_id=variable_pool.get_value("sys.workspace_id"),
user_id=state.get("user_id"),
variables=variable_pool.get_all_conversation_vars()
):
# 提取内容
content = chunk.get("content", "")
full_response += content
# 流式返回每个 chunk
yield {
"type": "chunk",
"node_id": self.node_id,
"content": content,
"full_content": full_response,
"meta_data": chunk.get("meta_data", {})
}
logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}")

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.services.multimodal_service import PROVIDER_STRATEGIES
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -548,9 +551,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict
)
@@ -614,16 +617,32 @@ class BaseNode(ABC):
return variable_pool.has(selector)
@staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None:
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
if isinstance(content, str):
if enable_file:
return {"text": content}
return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]()
result = await trans_tool.format_image(content["url"])
return result
raise TypeError('Unexpect input value type')
elif isinstance(content, FileObject):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, provider)
message = await multimodel_service.process_files(
[FileInput.model_construct(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
file_type=content.origin_file_type,
upload_file_id=content.file_id
)]
)
if message:
content.content_cache[provider] = message[0]
return message[0]
return None
raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod
def process_model_output(content) -> str:

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate
def merge_conv_vars(self):
self.variable_pool.get_all_conversation_vars().update(
self.child_variable_pool.get_all_conversation_vars()
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
self.child_variable_pool.variables["conv"]
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, file.value, self.typed_config.vision)
if content:
file_content.append(content)
if messages and messages[-1]["role"] == 'user':

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from app.schemas import FileType
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
return cls.NUMBER
elif isinstance(var, bool):
return cls.BOOLEAN
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE
elif isinstance(var, dict):
return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel):
type: FileType
url: str
__file: bool
transfer_method: str
origin_file_type: str
file_id: str | None
content_cache: dict = Field(default_factory=dict)
is_file: bool
class BaseVariable(ABC):

View File

@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
def valid_value(self, value) -> FileObject:
if isinstance(value, dict):
if not value.get("__file"):
if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject(
**{
"type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'),
"__file": True
"file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
}
)
if isinstance(value, FileObject):

View File

@@ -35,6 +35,7 @@ from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .implicit_emotions_storage_model import ImplicitEmotionsStorage
__all__ = [
"Tenants",
@@ -90,5 +91,6 @@ __all__ = [
"MemoryPerceptualModel",
"ModelBase",
"LoadBalanceStrategy",
"Skill"
"Skill",
"ImplicitEmotionsStorage"
]

View File

@@ -0,0 +1,45 @@
"""
Implicit Emotions Storage Model
数据库模型:存储用户的隐性记忆画像和情绪建议数据
替代原有的Redis缓存方式
"""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.db import Base
class ImplicitEmotionsStorage(Base):
"""隐性记忆和情绪存储表"""
__tablename__ = "implicit_emotions_storage"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID")
# 用户标识unique=True会自动创建唯一索引
end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID")
# 隐性记忆画像数据JSON格式
implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据")
# 情绪建议数据JSON格式
emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据")
# 时间戳
created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间")
updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
# 数据生成时间(用于业务逻辑)
implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间")
emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间")
# 索引只为updated_at创建索引end_user_id的unique约束已自动创建索引
__table_args__ = (
Index('idx_updated_at', 'updated_at'),
)
def __repr__(self):
return f"<ImplicitEmotionsStorage(id={self.id}, end_user_id={self.end_user_id})>"

View File

@@ -2,7 +2,7 @@ import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -78,6 +78,9 @@ class ModelConfig(BaseModel):
description = Column(String, comment="模型描述")
# 模型配置参数
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
config = Column(JSON, comment="模型配置参数")
# - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。
# - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。
@@ -118,6 +121,11 @@ class ModelApiKey(BaseModel):
api_key = Column(String, nullable=False, comment="API密钥")
api_base = Column(String, comment="API基础URL")
# 模型能力参数
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
# 配置参数
config = Column(JSON, comment="API Key特定配置")
@@ -155,6 +163,9 @@ class ModelBase(Base):
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作']")
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
# 关联关系
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")

View File

@@ -0,0 +1,169 @@
"""
Implicit Emotions Storage Repository
数据访问层:处理隐性记忆和情绪数据的数据库操作
事务由调用方控制,仓储层只使用 flush/refresh
"""
import logging
from datetime import datetime, date, timezone, timedelta
from typing import Optional, Generator
from sqlalchemy.orm import Session
from sqlalchemy import select, not_, exists
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.models.end_user_model import EndUser
logger = logging.getLogger(__name__)
class ImplicitEmotionsStorageRepository:
"""隐性记忆和情绪存储仓储类"""
def __init__(self, db: Session):
self.db = db
def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]:
"""根据终端用户ID获取存储记录"""
try:
stmt = select(ImplicitEmotionsStorage).where(
ImplicitEmotionsStorage.end_user_id == end_user_id
)
return self.db.execute(stmt).scalar_one_or_none()
except Exception as e:
logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}")
return None
def create(self, end_user_id: str) -> ImplicitEmotionsStorage:
"""创建新的存储记录(事务由调用方提交)"""
storage = ImplicitEmotionsStorage(
end_user_id=end_user_id,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
self.db.add(storage)
self.db.flush()
self.db.refresh(storage)
logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}")
return storage
def update_implicit_profile(
self,
end_user_id: str,
profile_data: dict
) -> ImplicitEmotionsStorage:
"""更新隐性记忆画像数据(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage is None:
storage = self.create(end_user_id)
storage.implicit_profile = profile_data
storage.implicit_generated_at = datetime.utcnow()
storage.updated_at = datetime.utcnow()
self.db.flush()
self.db.refresh(storage)
logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}")
return storage
def update_emotion_suggestions(
self,
end_user_id: str,
suggestions_data: dict
) -> ImplicitEmotionsStorage:
"""更新情绪建议数据(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage is None:
storage = self.create(end_user_id)
storage.emotion_suggestions = suggestions_data
storage.emotion_generated_at = datetime.utcnow()
storage.updated_at = datetime.utcnow()
self.db.flush()
self.db.refresh(storage)
logger.info(f"更新情绪建议成功: end_user_id={end_user_id}")
return storage
def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取所有已存储数据的用户ID避免大数据量内存溢出
Args:
batch_size: 每批次加载的数量默认100
Yields:
用户ID字符串
"""
offset = 0
while True:
try:
stmt = (
select(ImplicitEmotionsStorage.end_user_id)
.order_by(ImplicitEmotionsStorage.end_user_id)
.limit(batch_size)
.offset(offset)
)
batch = self.db.execute(stmt).scalars().all()
if not batch:
break
yield from batch
offset += batch_size
except Exception as e:
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
break
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
查询逻辑end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。
没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。
end_users.idUUID转为字符串后与 implicit_emotions_storage.end_user_idString对比。
Args:
batch_size: 每批次加载的数量默认100
Yields:
用户ID字符串
"""
from sqlalchemy import cast, String as SAString
CST = timezone(timedelta(hours=8))
now_cst = datetime.now(CST)
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
tomorrow_start = today_start + timedelta(days=1)
offset = 0
while True:
try:
stmt = (
select(EndUser.id)
.where(
EndUser.created_at >= today_start,
EndUser.created_at < tomorrow_start,
not_(
exists(
select(ImplicitEmotionsStorage.end_user_id).where(
ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString)
)
)
)
)
.order_by(EndUser.id)
.limit(batch_size)
.offset(offset)
)
batch = self.db.execute(stmt).scalars().all()
if not batch:
break
yield from (str(uid) for uid in batch)
offset += batch_size
except Exception as e:
logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}")
break
def delete_by_end_user_id(self, end_user_id: str) -> bool:
"""删除用户的存储记录(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage:
self.db.delete(storage)
self.db.flush()
logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}")
return True
return False

View File

@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
except Exception as e:
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
raise
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中permission_id='Memory'用户知识库的chunk_num总和
"""
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
try:
from sqlalchemy import func
result = db.query(func.sum(Knowledge.chunk_num)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id == "Memory"
).scalar()
total = result if result is not None else 0
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
return total
except Exception as e:
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
raise
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中排除用户知识库permission_id!='Memory')的数量
"""
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
try:
count = db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id != "Memory"
).count()
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
return count
except Exception as e:
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
raise

View File

@@ -374,7 +374,7 @@ class OntologySceneRepository:
count = self.db.query(OntologyScene).filter(
OntologyScene.scene_id == scene_id,
OntologyScene.workspace_id == workspace_id
(OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
).count()
is_owner = count > 0

View File

@@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel):
type: ApiKeyType = Field(..., description="API Key 类型")
scopes: List[str] = Field(default_factory=list, description="权限范围列表")
resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID")
rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制请求/秒)")
rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制请求/秒)")
daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1)
quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1)
expires_at: Optional[datetime.datetime] = Field(None, description="过期时间")
@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
@field_serializer('last_used_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
created_at: datetime.datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime.datetime) -> int:
def serialize_datetime(self, v: datetime.datetime) -> int:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)

View File

@@ -21,8 +21,14 @@ class FileType(StrEnum):
def trans(cls, value: str) -> 'FileType':
if value.startswith("image"):
return cls.IMAGE
# TODO: other file type support
raise RuntimeError("Unsupport file type")
elif value.startswith("document"):
return cls.DOCUMENT
elif value.startswith("audio"):
return cls.AUDIO
elif value.startswith("video"):
return cls.VIDEO
else:
raise RuntimeError("Unsupport file type")
class TransferMethod(str, Enum):
@@ -37,6 +43,12 @@ class FileInput(BaseModel):
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件IDlocal_file时必填")
url: Optional[str] = Field(None, description="远程URLremote_url时必填")
file_type: Optional[str] = Field(None, description="具体文件格式如image/jpg、audio/wav、document/docx、video/mp4")
def __init__(self, **data):
if "type" in data:
data['file_type'] = data['type']
super().__init__(**data)
@field_validator("type", mode="before")
@classmethod

View File

@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
class ChunkRetrieve(BaseModel):
query: str
kb_ids: list[uuid.UUID]
file_names_filter: list[str] | None = Field(None)
similarity_threshold: float | None = Field(None)
vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None)

View File

@@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel):
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开")
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
capability: List[str] = Field(default_factory=list, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型")
class ApiKeyCreateNested(BaseModel):
@@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel):
provider: Optional[str] = Field(None, description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
priority: str = Field("1", description="优先级", max_length=10)
@@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel):
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
is_active: Optional[bool] = Field(None, description="是否激活")
is_public: Optional[bool] = Field(None, description="是否公开")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
class ModelConfig(ModelConfigBase):
@@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel):
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
description: Optional[str] = Field(None, description="备注")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10)
@@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel):
provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10)
@@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel):
provider: Optional[ModelProvider] = Field(None, description="API Key提供商")
api_key: Optional[str] = Field(None, description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
is_active: Optional[bool] = Field(None, description="是否激活")
priority: Optional[str] = Field(None, description="优先级", max_length=10)
@@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel):
description: Optional[str] = Field(None, description="模型描述")
is_official: bool = Field(True, description="是否供应商官方模型")
tags: List[str] = Field(default_factory=list, description="模型标签")
capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video']")
is_omni: bool = Field(False, description="是否为Omni模型")
class ModelBaseUpdate(BaseModel):
@@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel):
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
is_official: Optional[bool] = Field(None, description="是否供应商官方模型")
tags: Optional[List[str]] = Field(None, description="模型标签")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
class ModelBase(BaseModel):
@@ -298,6 +314,8 @@ class ModelBase(BaseModel):
is_official: bool
tags: List[str]
add_count: int
capability: List[str] = []
is_omni: bool = False
class ModelBaseQuery(BaseModel):

View File

@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
class MultiAgentConfigCreate(BaseModel):
"""创建多 Agent 配置"""
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
orchestration_mode: str = Field(
default="collaboration",
pattern="^(collaboration|supervisor)$",
description="协作模式collaboration协作| supervisor监督"
)
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则")
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
aggregation_strategy: str = Field(
default="merge",
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
class MultiAgentConfigUpdate(BaseModel):
"""更新多 Agent 配置"""
master_agent_id: Optional[uuid.UUID] = None
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
model_parameters: Optional[ModelParameters] = Field(
None,

View File

@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)")
updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)")
classes_count: int = Field(0, description="类型数量")
is_system_default: bool = Field(False, description="是否为系统默认场景")
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
scene_id: UUID = Field(..., description="所属场景ID")
scene_name: str = Field(..., description="场景名称")
scene_description: Optional[str] = Field(None, description="场景描述")
is_system_default: bool = Field(False, description="是否为系统默认场景")
items: List[ClassResponse] = Field(..., description="类型列表")

View File

@@ -263,8 +263,8 @@ def create_agent_invocation_tool(
try:
# 9. 调用 Agent
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=agent_config,

View File

@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.workflow_service import WorkflowService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -39,6 +38,8 @@ class AppChatService:
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
self.agent_service = AgentRunService(db)
self.workflow_service = WorkflowService(db)
async def agnet_chat(
self,
@@ -55,12 +56,10 @@ class AppChatService:
files: Optional[List[FileInput]] = None # 新增:多模态文件
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
if variables is None:
variables = {}
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
model_config_id = config.default_model_config_id
@@ -79,74 +78,20 @@ class AppChatService:
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
memory_flag = False
if memory == True:
memory_config = config.memory
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if memory:
memory_tools, memory_flag = self.agent_service.load_memory_config(
config.memory, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 获取模型参数
model_parameters = config.model_parameters
@@ -157,6 +102,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -180,7 +126,7 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
@@ -245,10 +191,9 @@ class AppChatService:
try:
start_time = time.time()
config_id = None
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
if variables is None:
variables = {}
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
@@ -266,73 +211,22 @@ class AppChatService:
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
# 添加长期记忆工具
memory_flag = False
if memory:
memory_config = config.memory
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
memory_tools, memory_flag = self.agent_service.load_memory_config(
config.memory, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 获取模型参数
model_parameters = config.model_parameters
@@ -343,6 +237,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -366,13 +261,10 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent支持多模态
full_content = ""
total_tokens = 0
@@ -416,7 +308,7 @@ class AppChatService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
@@ -435,7 +327,7 @@ class AppChatService:
except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
async def multi_agent_chat(
self,
@@ -489,10 +381,10 @@ class AppChatService:
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
@@ -522,8 +414,6 @@ class AppChatService:
"""多 Agent 聊天(流式)"""
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
@@ -629,7 +519,6 @@ class AppChatService:
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
message=message,
variables=variables,
@@ -637,7 +526,7 @@ class AppChatService:
stream=True,
user_id=user_id
)
return await workflow_service.run(
return await self.workflow_service.run(
app_id=app_id,
payload=payload,
config=config,
@@ -664,7 +553,6 @@ class AppChatService:
) -> AsyncGenerator[dict, None]:
"""聊天(流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
message=message,
variables=variables,
@@ -673,7 +561,7 @@ class AppChatService:
user_id=user_id,
files=files
)
async for event in workflow_service.run_stream(
async for event in self.workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,

View File

@@ -232,7 +232,7 @@ class AppService:
# 检查主 Agent 的模型配置
multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id
model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))
@@ -1791,372 +1791,6 @@ class AppService:
return shares
# ==================== 试运行功能 ====================
async def draft_run(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent使用当前草稿配置
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Returns:
Dict: 包含 AI 回复和元数据的字典
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用试运行服务
logger.debug(
"准备调用试运行服务",
extra={
"app_id": str(app_id),
"model": model_config.name,
"has_conversation_id": bool(conversation_id),
"has_variables": bool(variables)
}
)
draft_service = DraftRunService(self.db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
)
logger.debug(
"试运行服务返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
"has_message": "message" in result if isinstance(result, dict) else False,
"has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
}
)
logger.info(
"试运行完成",
extra={
"app_id": str(app_id),
"elapsed_time": result.get("elapsed_time"),
"model": model_config.name
}
)
return result
async def draft_run_stream(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent流式返回
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Yields:
str: SSE 格式的事件数据
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用流式试运行服务
draft_service = DraftRunService(self.db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
):
yield event
# ==================== 多模型对比试运行 ====================
async def draft_run_compare(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
) -> Dict[str, Any]:
"""多模型对比试运行
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
parallel: 是否并行执行
timeout: 超时时间(秒)
Returns:
Dict: 对比结果
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(models),
"parallel": parallel
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的对比方法
draft_service = DraftRunService(self.db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
)
logger.info(
"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
"failed": result["failed_count"]
}
)
return result
async def draft_run_compare_stream(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
):
"""多模型对比试运行(流式返回)
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
timeout: 超时时间(秒)
Yields:
str: SSE 格式的事件数据
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比流式试运行",
extra={
"app_id": str(app_id),
"model_count": len(models)
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的流式对比方法
draft_service = DraftRunService(self.db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
):
yield event
logger.info(
"多模型对比流式完成",
extra={"app_id": str(app_id)}
)
# ==================== 向后兼容的函数接口 ====================
# 保留函数接口以兼容现有代码,但内部使用服务类
@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
return service.get_apps_by_ids(app_ids, workspace_id)
# ==================== 向后兼容的函数接口 ====================
async def draft_run(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent向后兼容接口"""
service = AppService(db)
return await service.draft_run(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
)
async def draft_run_stream(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent 流式返回(向后兼容接口)"""
service = AppService(db)
async for event in service.draft_run_stream(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
):
yield event
# ==================== 依赖注入函数 ====================
def get_app_service(

View File

@@ -0,0 +1,101 @@
"""
音频转文本服务
支持的服务商:
- DashScope (阿里云通义千问)
- OpenAI Whisper
"""
import httpx
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AudioTranscriptionService:
"""音频转文本服务"""
@staticmethod
async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
"""
使用阿里云通义千问语音识别服务转换音频为文本
Args:
audio_url: 音频文件 URL
api_key: DashScope API Key
Returns:
str: 转录的文本
"""
try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
},
json={
"model": "paraformer-v2",
"input": {
"file_urls": [audio_url]
},
"parameters": {
"language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
}
}
)
response.raise_for_status()
result = response.json()
if result.get("output", {}).get("results"):
text = result["output"]["results"][0].get("transcription_text", "")
logger.info(f"音频转文本成功: {len(text)} 字符")
return text
return "[音频转文本失败]"
except Exception as e:
logger.error(f"DashScope 音频转文本失败: {e}")
return f"[音频转文本失败: {str(e)}]"
@staticmethod
async def transcribe_openai(audio_url: str, api_key: str) -> str:
"""
使用 OpenAI Whisper 转换音频为文本
Args:
audio_url: 音频文件 URL
api_key: OpenAI API Key
Returns:
str: 转录的文本
"""
try:
# 下载音频文件
async with httpx.AsyncClient(timeout=60.0) as client:
audio_response = await client.get(audio_url)
audio_response.raise_for_status()
audio_data = audio_response.content
# 调用 Whisper API
files = {"file": ("audio.mp3", audio_data, "audio/mpeg")}
data = {"model": "whisper-1"}
response = await client.post(
"https://api.openai.com/v1/audio/transcriptions",
headers={"Authorization": f"Bearer {api_key}"},
files=files,
data=data
)
response.raise_for_status()
result = response.json()
text = result.get("text", "")
logger.info(f"音频转文本成功: {len(text)} 字符")
return text
except Exception as e:
logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
return f"[音频转文本失败: {str(e)}]"

View File

@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
"provider": api_key_config.provider,
"api_key": api_key_config.api_key,
"api_base": api_key_config.api_base,
"is_omni": api_key_config.is_omni,
"model_parameters": config_data.get("model_parameters", {}),
"api_key_id": api_key_config.id
}
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
provider=agent_config["provider"],
api_key=agent_config["api_key"],
base_url=agent_config.get("api_base"),
is_omni=agent_config.get("is_omni", False),
extra_params=extra_params
)

View File

@@ -17,15 +17,18 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
@@ -52,8 +55,12 @@ class LongTermMemoryInput(BaseModel):
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None):
def create_long_term_memory_tool(
memory_config: Dict[str, Any],
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
):
"""创建记忆工具,
@@ -61,6 +68,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
memory_config: 记忆配置
end_user_id: 用户ID
storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
Returns:
长期记忆工具
@@ -96,9 +104,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
from app.db import get_db
db = next(get_db())
try:
with get_db_context() as db:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
end_user_id=end_user_id,
@@ -120,9 +126,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -188,7 +191,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args:
query: 需要检索的问题或关键词
kb_config: 知识库配置
kb_ids: 知识库ID列表
user_id: 用户ID
Returns:
检索到的相关知识内容
@@ -232,17 +237,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
return knowledge_retrieval_tool
class DraftRunService:
"""运行服务类"""
class AgentRunService:
"""Agent运行服务类"""
def __init__(self, db: Session):
"""初始化试运行服务
"""Agent运行服务
Args:
db: 数据库会话
"""
self.db = db
@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: dict
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:
if variable.get("required") and variable.get("name") not in input_vars:
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
return input_vars
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置"""
if not tools_config:
return []
tools = []
tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list):
for tool_config in tools_config:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools
def load_skill_config(
self,
skills_config: dict | None,
message: str, tenant_id
) -> tuple[list, str]:
if not skills_config:
return [], ""
tools = []
skill_prompts = ""
skill_enable = skills_config.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills_config)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
skill_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
return tools, skill_prompts
def load_knowledge_retrieval_config(
self,
knowledge_retrieval_config: dict | None,
user_id
) -> list:
if not knowledge_retrieval_config:
return []
tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
return tools
def load_memory_config(
self,
memory_config: dict | None,
user_id,
storage_type,
user_rag_memory_id
) -> tuple[list, bool]:
"""加载长期记忆配置"""
if not memory_config:
return [], False
tools = []
if memory_config.get("enabled"):
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
return tools, bool(memory_config.get("enabled"))
async def run(
self,
*,
@@ -270,19 +399,21 @@ class DraftRunService:
conversation_id: 会话ID用于多轮对话
user_id: 用户ID
variables: 自定义变量参数值
storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
web_search: 是否启用网络搜索默认True
memory: 是否启用长期记忆默认True
sub_agent: 是否为子代理调用默认False
files: 多模态文件列表(可选)
Returns:
Dict: 包含 AI 回复和元数据的字典
"""
memory_flag = False
print('===========', storage_type)
print(user_id)
if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent
start_time = time.time()
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
try:
# 1. 获取 API Key 配置
@@ -302,112 +433,40 @@ class DraftRunService:
agent_config=agent_config
)
items_params = variables
if sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
system_prompt = render_prompt_message(
agent_config.system_prompt, # 修正拼写错误
agent_config.system_prompt,
PromptMessageRole.USER,
items_params
variables
)
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
print('系统提示词:', system_prompt)
# 4. 准备工具列表
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
print("+" * 50)
print(f"agent_config:{agent_config}")
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
# 添加长期记忆工具
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(
memory_config, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
@@ -415,6 +474,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -431,7 +491,7 @@ class DraftRunService:
# 6. 加载历史消息
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
@@ -442,7 +502,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -481,7 +541,7 @@ class DraftRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -556,16 +616,21 @@ class DraftRunService:
Yields:
str: SSE 格式的事件数据
"""
memory_flag = False
if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
start_time = time.time()
try:
# 1. 获取 API Key 配置
api_key_config = await self._get_api_key(model_config.id)
if not sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
# 2. 合并模型参数
effective_params = ModelParameterMerger.get_effective_parameters(
@@ -587,95 +652,22 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
for tool_config in agent_config.tools:
# print("+"*50)
# print(f"agent_config:{agent_config}")
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
# 添加长期记忆工具
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
@@ -683,6 +675,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -700,10 +693,10 @@ class DraftRunService:
# 6. 加载历史消息
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
max_history=memory_config.get("max_history", 10)
)
# 6. 处理多模态文件
@@ -711,7 +704,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -761,7 +754,7 @@ class DraftRunService:
})
# 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -809,7 +802,7 @@ class DraftRunService:
"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
"""获取模型的 API Key
Args:
@@ -846,7 +839,8 @@ class DraftRunService:
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
}
async def _ensure_conversation(
@@ -966,7 +960,6 @@ class DraftRunService:
List[Dict]: 历史消息列表
"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history(
@@ -1486,6 +1479,15 @@ class DraftRunService:
"conversation_id": returned_conversation_id,
"content": chunk
}))
if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
"model_config_id": model_config_id,
"label": model_label,
"conversation_id": returned_conversation_id,
"error": event_data.get("error", "未知错误")
}))
except Exception as e:
logger.warning(f"解析流式事件失败: {e}")
finally:
@@ -1670,41 +1672,3 @@ class DraftRunService:
"total_time": sum(r.get("elapsed_time", 0) for r in results)
}
)
async def draft_run(
db: Session,
*,
agent_config: AgentConfig,
model_config: ModelConfig,
message: str,
user_id: Optional[str] = None,
kb_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.7,
top_k: int = 3
) -> Dict[str, Any]:
"""试运行 Agent便捷函数
Args:
db: 数据库会话
agent_config: Agent 配置
model_config: 模型配置
message: 用户消息
user_id: 用户ID
kb_ids: 知识库ID列表
similarity_threshold: 相似度阈值
top_k: 检索返回的文档数量
Returns:
Dict: 包含 AI 回复和元数据的字典
"""
service = DraftRunService(db)
return await service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
user_id=user_id,
kb_ids=kb_ids,
similarity_threshold=similarity_threshold,
top_k=top_k
)

View File

@@ -843,32 +843,33 @@ class EmotionAnalyticsService:
end_user_id: str,
db: Session,
) -> Optional[Dict[str, Any]]:
""" Redis 缓存获取个性化情绪建议
"""数据库获取个性化情绪建议
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话(保留参数以保持接口兼容性)
db: 数据库会话
Returns:
Dict: 存的建议数据,如果不存在或已过期返回 None
Dict: 存的建议数据,如果不存在返回 None
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
logger.info(f"尝试从数据库获取情绪建议: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
# 从数据库获取存储记录
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
if storage is None or storage.emotion_suggestions is None:
logger.info(f"用户 {end_user_id} 的建议数据不存在")
return None
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
return cached_data
logger.info(f"成功从数据库获取建议: user={end_user_id}")
return storage.emotion_suggestions
except Exception as e:
logger.error(f" Redis 缓存获取建议失败: {str(e)}", exc_info=True)
logger.error(f"数据库获取建议失败: {str(e)}", exc_info=True)
return None
async def save_suggestions_cache(
@@ -876,36 +877,27 @@ class EmotionAnalyticsService:
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
expires_hours: int = 24 # 参数保留以保持接口兼容性
) -> None:
"""保存建议到 Redis 缓存
"""保存建议到数据库
Args:
end_user_id: 宿主ID用户组ID
suggestions_data: 建议数据
db: 数据库会话(保留参数以保持接口兼容性)
expires_hours: 过期时间小时默认24小时
db: 数据库会话
expires_hours: 保留参数(兼容性)
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
logger.info(f"保存建议到数据库: user={end_user_id}")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_emotion_suggestions(end_user_id, suggestions_data)
db.commit()
# 保存到 Redis
success = await EmotionMemoryCache.set_emotion_suggestions(
user_id=end_user_id,
suggestions_data=suggestions_data,
expire=expire_seconds
)
if success:
logger.info(f"建议缓存保存成功: user={end_user_id}")
else:
logger.warning(f"建议缓存保存失败: user={end_user_id}")
logger.info(f"建议保存成功: user={end_user_id}")
except Exception as e:
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程
db.rollback()
logger.error(f"保存建议失败: {str(e)}", exc_info=True)

View File

@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
is_omni=model_api_key.is_omni,
extra_params={
"temperature": 0.7,
"max_tokens": 2000,

View File

@@ -422,32 +422,33 @@ class ImplicitMemoryService:
end_user_id: str,
db: Session
) -> Optional[dict]:
""" Redis 缓存获取完整用户画像
"""数据库获取完整用户画像
Args:
end_user_id: 终端用户ID
db: 数据库会话(保留参数以保持接口兼容性)
db: 数据库会话
Returns:
Dict: 存的画像数据,如果不存在或已过期返回 None
Dict: 存的画像数据,如果不存在返回 None
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}")
logger.info(f"尝试从数据库获取用户画像: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id)
# 从数据库获取存储记录
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
if storage is None or storage.implicit_profile is None:
logger.info(f"用户 {end_user_id} 的画像数据不存在")
return None
logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}")
return cached_data
logger.info(f"成功从数据库获取用户画像: user={end_user_id}")
return storage.implicit_profile
except Exception as e:
logger.error(f" Redis 缓存获取用户画像失败: {str(e)}", exc_info=True)
logger.error(f"数据库获取用户画像失败: {str(e)}", exc_info=True)
return None
async def save_profile_cache(
@@ -455,36 +456,27 @@ class ImplicitMemoryService:
end_user_id: str,
profile_data: dict,
db: Session,
expires_hours: int = 168 # 默认7天
expires_hours: int = 168 # 参数保留以保持接口兼容性
) -> None:
"""保存用户画像到 Redis 缓存
"""保存用户画像到数据库
Args:
end_user_id: 终端用户ID
profile_data: 画像数据
db: 数据库会话(保留参数以保持接口兼容性)
expires_hours: 过期时间小时默认168小时7天
db: 数据库会话
expires_hours: 保留参数(兼容性
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
logger.info(f"保存用户画像到数据库: user={end_user_id}")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_implicit_profile(end_user_id, profile_data)
db.commit()
# 保存到 Redis
success = await ImplicitMemoryCache.set_user_profile(
user_id=end_user_id,
profile_data=profile_data,
expire=expire_seconds
)
if success:
logger.info(f"用户画像缓存保存成功: user={end_user_id}")
else:
logger.warning(f"用户画像缓存保存失败: user={end_user_id}")
logger.info(f"用户画像保存成功: user={end_user_id}")
except Exception as e:
logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程
db.rollback()
logger.error(f"保存用户画像失败: {str(e)}", exc_info=True)

View File

@@ -9,6 +9,8 @@ load_dotenv()
# 读取web_search环境变量
web_search_value = os.getenv('web_search')
def Search(query):
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
api_key = web_search_value
@@ -18,23 +20,24 @@ def Search(query):
"role": "user",
"content": query
}
], #搜索输入
"edition":"standard", #搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", #使用的搜索引擎版本
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
], # 搜索输入
"edition": "standard", # 搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", # 使用的搜索引擎版本
"resource_type_filter": [{"type": "web", "top_k": 20}],
# 支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
"search_filter": {
"range": {
"page_time": {
"gte": "now-1w/d", #时间查询参数,大于或等于
"lt": "now/d", #时间查询参数,小于
"gt": "", #时间查询参数,大于
"lte": "" #时间查询参数,小于或等于
"gte": "now-1w/d", # 时间查询参数,大于或等于
"lt": "now/d", # 时间查询参数,小于
"gt": "", # 时间查询参数,大于
"lte": "" # 时间查询参数,小于或等于
}
}
},
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
"search_recency_filter":"week", #根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content":True #是否输出网页完整原文
"block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表
"search_recency_filter": "week", # 根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content": True # 是否输出网页完整原文
}, ensure_ascii=False)
headers = {
'Content-Type': 'application/json',
@@ -42,10 +45,10 @@ def Search(query):
}
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
content=[]
content = []
for i in response['references']:
title=i['title']
snippet=i['snippet']
content.append(title+';'+snippet)
content=''.join(content)
return content
title = i['title']
snippet = i['snippet']
content.append(title + ';' + snippet)
content = ''.join(content)
return content

View File

@@ -414,6 +414,7 @@ class LLMRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.3,
max_tokens=500
)

View File

@@ -392,6 +392,7 @@ class MasterAgentRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
extra_params = extra_params
)

View File

@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
"""
import json
import os
import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
@@ -88,8 +86,6 @@ class MemoryAgentService:
raise ValueError(f"写入失败: {messages}")
def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
"""
Process write operation with config_id
@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(100 * '-')
print(langchain_messages)
print(100*'-')
print(100 * '-')
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
async def read_memory(
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID]|int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id
@@ -425,7 +423,7 @@ class MemoryAgentService:
import time
start_time = time.time()
ori_message= message
ori_message = message
# Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError:
audit_logger = None
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -562,34 +560,35 @@ class MemoryAgentService:
from app.repositories.memory_short_repository import (
ShortTermMemoryRepository,
)
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
try:
reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
@@ -602,15 +601,17 @@ class MemoryAgentService:
)
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
logger.debug(
f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -657,41 +657,43 @@ class MemoryAgentService:
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
async def generate_summary_from_retrieve(
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
基于检索信息、历史对话和查询生成最终答案
@@ -761,9 +764,9 @@ class MemoryAgentService:
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try:
# 加载配置
config_service = MemoryConfigService(db)
@@ -772,7 +775,7 @@ class MemoryAgentService:
workspace_id=workspace_id,
service_name="MemoryAgentService"
)
# 导入必要的模块
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm,
@@ -780,13 +783,13 @@ class MemoryAgentService:
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
)
# 构建状态对象
state = {
"data": query,
"memory_config": memory_config
}
# 直接调用 summary_llm 函数
answer = await summary_llm(
state=state,
@@ -797,21 +800,20 @@ class MemoryAgentService:
response_model=RetrieveSummaryResponse,
search_mode="1"
)
logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"
except Exception as e:
logger.error(f"生成摘要失败: {str(e)}", exc_info=True)
return "信息不足,无法回答。"
async def get_knowledge_type_stats(
self,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None,
db: Session = None
self,
db: Session,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
统计知识库类型分布,包含:
@@ -837,11 +839,6 @@ class MemoryAgentService:
# 1. 统计 PostgreSQL 中的知识库类型
try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)
# 初始化所有标准类型为 0
for kb_type in KnowledgeType:
result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:
# 3. 计算知识库类型总和(不包括 memory
result["total"] = (
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
)
return result
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的兴趣分布标签。
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:
# 定义标签提取的结构
class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
tags: list[str] = Field(...,
description="3个描述用户特征的标签产品设计师、旅行爱好者、摄影发烧友")
messages = [
{
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
import json as json_module
import uuid
from sqlalchemy import select
@@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填
memory_config_id_to_use = end_user.memory_config_id
# 如果已有 memory_config_id直接使用
# 如果新创建enduserenduser.memory_config_id 必定为none
# 那么使用从release中获取memory_config_id为预期行为并且回填到
# end_user.memory_config_id
if not memory_config_id_to_use:
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")
# 获取最新发布版本
stmt = (
select(AppRelease)
@@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
)
# TODO: change to current_release_id
latest_release = db.scalars(stmt).first()
if latest_release:
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}
# 使用 MemoryConfigService 的提取方法
memory_config_service = MemoryConfigService(db)
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 验证提取的 config_id 是否存在于数据库中
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
existing_config = db.get(MemoryConfigModel, legacy_config_id)
if existing_config:
memory_config_id_to_use = legacy_config_id
# 回填到 end_user 表lazy update
end_user.memory_config_id = memory_config_id_to_use
db.commit()
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id)
}
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
@@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
# 创建映射 - 保留 EndUser 对象引用以便回填
end_user_map = {str(eu.id): eu for eu in end_users}
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
@@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
users_needing_migration = [
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
(end_user_id, data["app_id"])
for end_user_id, data in user_data.items()
if not data["memory_config_id"]
]
if users_needing_migration:
# 批量获取相关应用的最新发布版本
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
# 查询每个应用的最新活跃发布版本
app_latest_releases = {}
for app_id in migration_app_ids:
@@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
latest_release = db.scalars(stmt).first()
if latest_release:
app_latest_releases[app_id] = latest_release
# 为每个需要迁移的用户提取 memory_config_id
config_service = MemoryConfigService(db)
users_to_backfill = [] # [(end_user, memory_config_id), ...]
for end_user_id, app_id in users_needing_migration:
latest_release = app_latest_releases.get(app_id)
if not latest_release:
continue
config = latest_release.config or {}
# 如果 config 是字符串,解析为字典
if isinstance(config, str):
try:
@@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
continue
# 使用 MemoryConfigService 的提取方法
app = app_map.get(app_id)
if not app:
continue
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
app_type=app.type,
config=config
)
if legacy_config_id:
# 更新 user_data 中的 memory_config_id
user_data[end_user_id]["memory_config_id"] = legacy_config_id
# 记录需要回填的用户(稍后验证配置存在后再回填)
end_user = end_user_map.get(end_user_id)
if end_user:
@@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
logger.info(
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
)
# 验证提取的 config_id 是否存在于数据库中
if users_to_backfill:
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
@@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
MemoryConfig.config_id.in_(config_ids_to_validate)
).all()
valid_config_ids = {mc.config_id for mc in existing_configs}
# 只回填存在的配置
valid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid in valid_config_ids
]
invalid_backfills = [
(eu, cid) for eu, cid in users_to_backfill
(eu, cid) for eu, cid in users_to_backfill
if cid not in valid_config_ids
]
if invalid_backfills:
invalid_ids = [str(cid) for _, cid in invalid_backfills]
logger.warning(
@@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 清除 user_data 中无效的 config_id
for eu, cid in invalid_backfills:
user_data[str(eu.id)]["memory_config_id"] = None
# 批量回填 end_user.memory_config_id
if valid_backfills:
for end_user, memory_config_id in valid_backfills:
@@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
direct_config_ids = []
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
for end_user_id, data in user_data.items():
if data["memory_config_id"]:
direct_config_ids.append(data["memory_config_id"])
@@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
workspace_default_configs = {}
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
if unique_workspace_ids:
config_service = MemoryConfigService(db)
for workspace_id in unique_workspace_ids:
@@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
# 7. 构建最终结果
for end_user_id, data in user_data.items():
memory_config = None
# 优先使用 end_user 直接分配的配置
if data["memory_config_id"]:
memory_config = config_id_to_config.get(data["memory_config_id"])
# 回退到工作空间默认配置
if not memory_config:
workspace_id = app_to_workspace.get(data["app_id"])
@@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
logger.info(f"Successfully retrieved {len(result)} connected configs")
return result
return result

View File

@@ -140,9 +140,11 @@ class MemoryAPIService:
try:
# Delegate to MemoryAgentService
# Convert string message to list[dict] format expected by MemoryAgentService
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
result = await MemoryAgentService().write_memory(
end_user_id=end_user_id,
messages=message,
messages=messages,
config_id=config_id,
db=self.db,
storage_type=storage_type,
@@ -151,9 +153,18 @@ class MemoryAPIService:
logger.info(f"Memory write successful for end_user: {end_user_id}")
# result may be a string "success" or a dict with a "status" key
# Preserve the full dict so callers don't silently lose extra fields
# (e.g. error codes, metadata) returned by MemoryAgentService.
if isinstance(result, dict):
return {
**result,
"status": result.get("status", "unknown"),
"end_user_id": end_user_id,
}
return {
"status": "success" if result == "success" else result,
"end_user_id": end_user_id
"status": result if isinstance(result, str) else "success",
"end_user_id": end_user_id,
}
except ConfigurationError as e:

View File

@@ -390,19 +390,59 @@ def get_rag_total_kb(
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库permission_id!='Memory'的数量
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
return total_kb
except Exception as e:
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_user_kb_total_chunk(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id从documents表统计所有用户知识库的chunk总数。
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from app.models.end_user_model import EndUser
from app.models.app_model import App
from sqlalchemy import func
# 通过 App 关联取该 workspace 下所有 end_user_id
end_user_ids = [
str(eid) for (eid,) in db.query(EndUser.id)
.join(App, EndUser.app_id == App.id)
.filter(App.workspace_id == workspace_id)
.all()
]
if not end_user_ids:
return 0
file_names = [f"{uid}.txt" for uid in end_user_ids]
result = db.query(func.sum(Document.chunk_num)).filter(
Document.file_name.in_(file_names)
).scalar()
total_chunk = int(result or 0)
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
return total_chunk
except Exception as e:
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_current_user_total_chunk(
end_user_id: str,
db: Session,

View File

@@ -1,45 +1,42 @@
# 修改 memory_konwledges_server.py 文件
import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.db import get_db_context
from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status
from app.core.config import settings
from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db
# 创建一个简单的用户类用于测试
api_logger = get_api_logger()
class ChunkCreate(BaseModel):
content: str
class SimpleUser:
def __init__(self, user_id: str):
# 确保ID是UUID类型
self.id = user_id
self.username = user_id
'''解析'''
async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
"""
解析指定文档
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user
api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}")
raise
'''获取块ID'''
async def get_document_chunks(
kb_id: uuid.UUID,
document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(
return success(data=result, msg="文档块列表查询成功")
'''查找文档ID'''
def find_document_id_by_kb_and_filename(
db: Session,
kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
except Exception as e:
return None
'''获取知识库ID'''
def find_documents_by_kb_id(
db: Session,
kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
except Exception as e:
return []
''''上传文件'''
async def memory_konwledges_up(
kb_id: str,
parent_id: str,
create_data: file_schema.CustomTextFileCreate,
db: Session = Depends(get_db),
current_user: SimpleUser = None, # 修改为SimpleUser
db: Session,
current_user: SimpleUser,
):
# 如果没有提供current_user则创建一个默认的
if current_user is None:
current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")
content_bytes = create_data.content.encode('utf-8')
file_size = len(content_bytes)
print(f"file size: {file_size} byte")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(
return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")
'''添加新块'''
async def create_document_chunk(
kb_id: uuid.UUID,
@@ -417,7 +408,7 @@ async def create_document_chunk(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"查询文档块失败: {error_msg}"
)
sort_id = sort_id + 1
# 5. 创建文档块
@@ -450,6 +441,7 @@ async def create_document_chunk(
return success(data=chunk, msg="文档块创建成功")
async def write_rag(end_user_id, message, user_rag_memory_id):
"""
将消息写入 RAG 知识库
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
detail=f"知识库ID格式无效: {user_rag_memory_id}"
)
db_gen = get_db()
db = next(db_gen)
try:
with get_db_context() as db:
create_data = CustomTextFileCreate(title=end_user_id, content=message)
current_user = SimpleUser(user_rag_memory_id)
# 检查文档是否已存在
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
print('======',document)
print('======', document)
api_logger.info(f"查找文档结果: document_id={document}")
if document is not None:
# 文档已存在,直接添加新块
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
else:
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
return result
finally:
# 确保数据库会话被关闭
db.close()

View File

@@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
# 业务层检查同一工作空间下是否已存在同名配置
if params.workspace_id and params.config_name:
from app.models.memory_config_model import MemoryConfig
existing = (
self.db.query(MemoryConfig)
.filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
.first()
)
if existing:
raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")
# 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id)
@@ -211,6 +222,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"apply_id": config.apply_id,
"scene_id": str(config.scene_id) if config.scene_id else None,
"scene_name": scene_name, # 新增:场景名称
"is_system_default": config.is_default, # 是否为系统默认配置
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,

View File

@@ -90,7 +90,8 @@ class ModelConfigService:
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello"
test_message: str = "Hello",
is_omni: bool = False
) -> Dict[str, Any]:
"""验证模型配置是否有效
@@ -102,6 +103,7 @@ class ModelConfigService:
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
is_omni: 是否为Omni模型
Returns:
Dict: 验证结果
@@ -119,6 +121,7 @@ class ModelConfigService:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
temperature=0.7,
max_tokens=100
)
@@ -257,8 +260,9 @@ class ModelConfigService:
provider=model_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type, # 传递模型类型
test_message="Hello"
model_type=model_data.type,
test_message="Hello",
is_omni=model_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -279,6 +283,9 @@ class ModelConfigService:
for api_key_data in api_key_datas:
api_key_data.model_name = model_data.name
api_key_data.provider = model_data.provider
# 同步capability和is_omni
api_key_data.capability = model_data.capability
api_key_data.is_omni = model_data.is_omni
api_key_create_schema = ModelApiKeyCreate(
model_config_ids=[model.id],
**api_key_data.model_dump()
@@ -473,6 +480,9 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
continue
data.is_omni = model_config.is_omni
data.capability = model_config.capability
# 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name
@@ -497,6 +507,8 @@ class ModelApiKeyService:
existing_key.config = data.config
existing_key.priority = data.priority
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -513,7 +525,8 @@ class ModelApiKeyService:
api_key=data.api_key,
api_base=data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=data.is_omni
)
if not validation_result["valid"]:
# 记录验证失败的模型,但不抛出异常
@@ -528,6 +541,8 @@ class ModelApiKeyService:
provider=data.provider,
api_key=data.api_key,
api_base=data.api_base,
capability=data.capability,
is_omni=data.is_omni,
config=data.config,
is_active=data.is_active,
priority=data.priority
@@ -550,6 +565,10 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if api_key_data.is_omni is None:
api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None:
api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join(
@@ -572,6 +591,8 @@ class ModelApiKeyService:
existing_key.config = api_key_data.config
existing_key.priority = api_key_data.priority
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -589,7 +610,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=api_key_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -620,7 +642,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key or existing_api_key.api_key,
api_base=api_key_data.api_base or existing_api_key.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -755,6 +778,8 @@ class ModelBaseService:
"type": model_base.type,
"logo": model_base.logo,
"description": model_base.description,
"capability": model_base.capability,
"is_omni": model_base.is_omni,
"is_composite": False
}
model_config = ModelConfigRepository.create(db, model_config_data)

View File

@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
user_id: 用户 ID
variables: 变量参数
use_llm_routing: 是否使用 LLM 路由
web_search: 是否启用网络搜索
memory: 是否启用记忆功能
storage_type: 存储类型
user_rag_memory_id: 用户 RAG 记忆 ID
Yields:
SSE 格式的事件流
"""
import json
start_time = time.time()
@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
except Exception as e:
logger.error(
"多 Agent 任务执行失败(流式)",
extra={"error": str(e), "mode": self._normalized_mode}
extra={"error": str(e), "mode": self._normalized_mode},
exc_info=True
)
# 发送错误事件
yield self._format_sse_event("error", {
@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
Yields:
SSE 格式的事件流
"""
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
# 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
)
# 流式执行 Agent
draft_service = DraftRunService(self.db)
draft_service = AgentRunService(self.db)
async for event in draft_service.run_stream(
agent_config=agent_config,
model_config=model_config,
@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
Returns:
执行结果
"""
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
# 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
)
# 执行 Agent
draft_service = DraftRunService(self.db)
draft_service = AgentRunService(self.db)
result = await draft_service.run(
agent_config=agent_config,
model_config=model_config,
@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
self.memory = config_data.get("memory")
self.variables = config_data.get("variables", [])
self.tools = config_data.get("tools", {})
self.skills = config_data.get("skills", {})
self.default_model_config_id = release.default_model_config_id
return AgentConfigProxy(release, app, config_data)
@@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7, # 整合任务使用中等温度
max_tokens=2000
)
@@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7,
max_tokens=2000,
extra_params={"streaming": True} # 启用流式输出

View File

@@ -267,7 +267,7 @@ class MultiAgentService:
# 2. 验证模型配置(如果提供了)
if data.default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(data.default_model_config_id))

View File

@@ -9,47 +9,100 @@
- OpenAI: 支持 URL 和 base64 格式
"""
import uuid
from typing import List, Dict, Any, Optional, Protocol
import httpx
import base64
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from docx import Document
import io
import PyPDF2
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.models.generic_file_model import GenericFile
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
from app.services.audio_transcription_service import AudioTranscriptionService
logger = get_business_logger()
class ImageFormatStrategy(Protocol):
"""图片格式策略接口"""
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
@abstractmethod
async def format_image(self, url: str) -> Dict[str, Any]:
"""格式化图片"""
pass
@abstractmethod
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""格式化文档"""
pass
@abstractmethod
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
"""格式化音频"""
pass
@abstractmethod
async def format_video(self, url: str) -> Dict[str, Any]:
"""格式化视频"""
pass
class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""将图片 URL 转换为特定 provider 的格式"""
...
class DashScopeImageStrategy:
"""通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}"""
"""通义千问图片格式:{"type": "image", "image": "url"}"""
return {
"type": "image",
"image": url
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""通义千问文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略"""
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
通义千问音频格式
- 原生支持: qwen-audio 系列
- 其他模型: 需要转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
"type": "audio",
"audio": url
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""通义千问视频格式qwen-vl 系列原生支持)"""
return {
"type": "video",
"video": url
}
class BedrockFormatStrategy(MultimodalFormatStrategy):
"""Bedrock/Anthropic 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
"""
import httpx
import base64
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码
text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8')
class OpenAIImageStrategy:
"""OpenAI 图片格式策略"""
return {
"type": "document",
"source": {
"type": "base64",
"media_type": "text/plain",
"data": base64_text
}
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
Bedrock/Anthropic 音频格式
不支持原生音频,必须转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"[音频转录]\n{transcription}"
}
return {
"type": "text",
"text": "[音频文件Bedrock 不支持原生音频,请启用音频转文本功能]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 视频格式"""
return {
"type": "text",
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
}
class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""OpenAI 文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
OpenAI 音频格式
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
- 其他模型使用转录文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# OpenAI 音频需要 base64 编码
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
base64_audio = base64.b64encode(audio_data).decode('utf-8')
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
ct = response.headers.get("content-type", "")
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
}
except Exception as e:
logger.error(f"下载音频失败: {e}")
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""OpenAI 视频格式"""
return {
"type": "video_url",
"video_url": {
"url": url
}
}
# Provider 到策略的映射
PROVIDER_STRATEGIES = {
"dashscope": DashScopeImageStrategy,
"bedrock": BedrockImageStrategy,
"anthropic": BedrockImageStrategy,
"openai": OpenAIImageStrategy,
"dashscope": DashScopeFormatStrategy,
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
}
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"):
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
"""
初始化多模态服务
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic 等)
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
@@ -137,20 +295,32 @@ class MultimodalService:
if not files:
return []
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy
strategy = strategy_class()
result = []
for idx, file in enumerate(files):
try:
if file.type == FileType.IMAGE:
content = await self._process_image(file)
content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file)
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
content = await self._process_audio(file)
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
content = await self._process_video(file)
content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
@@ -172,55 +342,29 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
Args:
file: 图片文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"}
Dict: 根据 provider 返回不同格式的图片内容
"""
url = await self.get_file_url(file)
logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
try:
logger.info(f"开始下载并编码图片: {url}")
base64_data, media_type = await self._download_and_encode_image(url)
result = {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data[:100] + "..." # 只记录前100个字符
}
}
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
# 返回完整数据
result["source"]["data"] = base64_data
return result
except Exception as e:
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
# 返回错误提示
return {
"type": "text",
"text": f"[图片加载失败: {str(e)}]"
}
else:
# 通义千问等其他格式支持 URL
try:
url = await self.get_file_url(file)
return await strategy.format_image(url)
except Exception as e:
logger.error(f"处理图片失败: {e}", exc_info=True)
return {
"type": "image",
"image": url
"type": "text",
"text": f"[图片处理失败: {str(e)}]"
}
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
@@ -230,8 +374,6 @@ class MultimodalService:
Returns:
tuple: (base64_data, media_type)
"""
import httpx
import base64
from mimetypes import guess_type
# 下载图片
@@ -258,15 +400,16 @@ class MultimodalService:
return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
Args:
file: 文档文件输入
strategy: 格式化策略
Returns:
Dict: text 格式的内容(包含提取的文本)
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
@@ -277,48 +420,68 @@ class MultimodalService:
else:
# 本地文件,提取文本内容
text = await self._extract_document_text(file.upload_file_id)
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
file_name = generic_file.file_name if generic_file else "unknown"
file_name = file_metadata.file_name if file_metadata else "unknown"
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
# 使用策略格式化文档
return await strategy.format_document(file_name, text)
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理音频文件
Args:
file: 音频文件输入
strategy: 格式化策略
Returns:
Dict: 音频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的音频内容
"""
# TODO: 实现音频转文字功能
return {
"type": "text",
"text": "[音频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
return await strategy.format_audio(file.file_type, url, transcription)
except Exception as e:
logger.error(f"处理音频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理视频文件
Args:
file: 视频文件输入
strategy: 格式化策略
Returns:
Dict: 视频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的视频内容
"""
# TODO: 实现视频处理功能
return {
"type": "text",
"text": "[视频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
return await strategy.format_video(url)
except Exception as e:
logger.error(f"处理视频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[视频处理失败: {str(e)}]"
}
async def get_file_url(self, file: FileInput) -> str:
"""
@@ -336,26 +499,22 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
# 本地文件,通过 file_storage 系统获取永久访问 URL
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
file_id = file.upload_file_id
print("="*50)
print("file_id",file_id)
# 查询 FileMetadata
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 返回永久URL
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
@@ -370,58 +529,79 @@ class MultimodalService:
Returns:
str: 提取的文本内容
"""
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file_id,
GenericFile.status == "active"
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not generic_file:
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx
# - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower()
file_ext = file_metadata.file_ext.lower()
server_url = settings.FILE_LOCAL_SERVER_URL
file_url = f"{server_url}/storage/permanent/{file_id}"
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path)
return await self._read_text_file(file_url)
elif file_ext == '.pdf':
return await self._extract_pdf_text(generic_file.storage_path)
return await self._extract_pdf_text(file_url)
elif file_ext in ['.doc', '.docx']:
return await self._extract_word_text(generic_file.storage_path)
return await self._extract_word_text(file_url)
else:
return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str:
@staticmethod
async def _read_text_file(file_url: str) -> str:
"""读取纯文本文件"""
try:
with open(storage_path, 'r', encoding='utf-8') as f:
return f.read()
# 下载文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
return response.text
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_pdf_text(file_url: str) -> str:
"""提取 PDF 文本"""
try:
# TODO: 实现 PDF 文本提取
# import PyPDF2 或 pdfplumber
return "[PDF 文本提取功能待实现]"
# 下载 PDF 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
pdf_data = response.content
# 使用 BytesIO 读取 PDF
text_parts = []
pdf_file = io.BytesIO(pdf_data)
pdf_reader = PyPDF2.PdfReader(pdf_file)
for page in pdf_reader.pages:
text_parts.append(page.extract_text())
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_word_text(file_url: str) -> str:
"""提取 Word 文档文本"""
try:
# TODO: 实现 Word 文本提取
# import docx
return "[Word 文本提取功能待实现]"
# 下载 Word 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
word_data = response.content
# 使用 BytesIO 读取 Word 文档
word_file = io.BytesIO(word_data)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 Word 文本失败: {e}")
return f"[Word 提取失败: {str(e)}]"

View File

@@ -184,7 +184,8 @@ class PromptOptimizerService:
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base
base_url=api_config.api_base,
is_omni=api_config.is_omni
), type=ModelType(model_config.type))
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')

View File

@@ -247,6 +247,7 @@ class SharedChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -454,6 +455,7 @@ class SharedChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,

View File

@@ -121,7 +121,7 @@ class SkillService:
if skill and skill.is_active:
# 加载技能关联的工具
for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)

View File

@@ -8,6 +8,8 @@ from datetime import datetime
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository,
@@ -79,6 +81,18 @@ class ToolService:
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
return self._config_to_info(config) if config else None
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
"""检查工具名称是否重复"""
query = self.db.query(ToolConfig).filter(
ToolConfig.name == name,
ToolConfig.tool_type == tool_type.value,
ToolConfig.tenant_id == tenant_id
)
if exclude_id:
query = query.filter(ToolConfig.id != exclude_id)
if query.first():
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
def create_tool(
self,
name: str,
@@ -92,6 +106,7 @@ class ToolService:
"""创建工具"""
if tool_type == ToolType.BUILTIN:
raise ValueError("内置工具不允许创建")
self._check_name_duplicate(name, tool_type, tenant_id)
try:
# 创建基础配置
@@ -141,6 +156,7 @@ class ToolService:
raise ValueError("内置工具不允许修改名称、描述和图标")
try:
if name:
self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id)
config_obj.name = name
if description:
config_obj.description = description
@@ -209,7 +225,7 @@ class ToolService:
try:
# 获取工具实例
tool = self._get_tool_instance(tool_id, tenant_id)
tool = self.get_tool_instance(tool_id, tenant_id)
if not tool:
return ToolResult.error_result(
error=f"工具不存在: {tool_id}",
@@ -335,7 +351,7 @@ class ToolService:
return []
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return []
@@ -792,7 +808,7 @@ class ToolService:
"""获取工具配置"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例"""
if tool_id in self._tool_cache:
return self._tool_cache[tool_id]
@@ -1416,7 +1432,7 @@ class ToolService:
"""测试内置工具连接"""
try:
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return {"success": False, "message": "无法创建工具实例"}

View File

@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
if end_user_id:
try:
# 获取数据库会话并查询用户信息
db = next(get_db())
try:
with get_db_context() as db:
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")

View File

@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,11 +454,14 @@ class WorkflowService:
files_struct = []
for file in files:
files_struct.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
FileObject(
type=file.type,
url=await self.multimodal_service.get_file_url(file),
transfer_method=file.transfer_method,
file_id=str(file.upload_file_id),
origin_file_type=file.file_type,
is_file=True
).model_dump()
)
return files_struct

View File

@@ -107,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
for workspace in workspaces:
if workspace.storage_type == 'neo4j':
_ensure_default_memory_config(db, workspace)
_ensure_default_ontology_scenes(db, workspace)
business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}")
return workspaces
@@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults(
)
def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
"""Ensure a workspace has default ontology scenes, creating them if missing.
Checks whether any is_system_default scene exists for the workspace.
If not, runs the DefaultOntologyInitializer to create them.
Args:
db: Database session
workspace: The workspace to check
"""
from app.models.ontology_scene import OntologyScene
# 幂等检查:是否已存在系统默认场景
existing = db.query(OntologyScene).filter(
OntologyScene.workspace_id == workspace.id,
OntologyScene.is_system_default.is_(True)
).first()
if existing:
return
business_logger.info(
f"Workspace {workspace.id} missing default ontology scenes, creating them"
)
try:
initializer = DefaultOntologyInitializer(db)
success, error_msg = initializer.initialize_default_scenes(
workspace.id, language="zh"
)
if success:
db.commit()
business_logger.info(
f"为工作空间 {workspace.id} 补建默认本体场景成功"
)
else:
business_logger.warning(
f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}"
)
except Exception as e:
db.rollback()
business_logger.error(
f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}"
)
def _create_default_memory_config(
db: Session,
workspace_id: uuid.UUID,

View File

@@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
return result
try:
def sync_task():
trio.run(
lambda: _run(
row=task,
@@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community,
)
)
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
future.result() # Blocks until the task completes
except Exception as e:
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
@@ -2108,4 +2112,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# }
# =============================================================================
# 隐性记忆和情绪数据更新定时任务
# =============================================================================
@celery_app.task(
name="app.tasks.update_implicit_emotions_storage",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=7200, # 2小时硬超时
soft_time_limit=6900, # 1小时55分钟软超时
)
def update_implicit_emotions_storage(self) -> Dict[str, Any]:
"""定时任务:更新所有用户的隐性记忆画像和情绪建议数据
遍历数据库中所有已存在数据的用户,为每个用户重新生成隐性记忆画像和情绪建议。
实现错误隔离,单个用户失败不影响其他用户的处理。
Returns:
包含任务执行结果的字典,包括:
- status: 任务状态 (SUCCESS/FAILURE)
- message: 执行消息
- total_users: 总用户数
- successful_implicit: 成功更新隐性记忆的用户数
- successful_emotion: 成功更新情绪建议的用户数
- failed: 失败的用户数
- user_results: 每个用户的详细结果
- elapsed_time: 执行耗时(秒)
- task_id: 任务ID
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from sqlalchemy import select, func
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.emotion_analytics_service import EmotionAnalyticsService
logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
total_users = 0
successful_implicit = 0
successful_emotion = 0
failed = 0
user_results = []
with get_db_context() as db:
try:
# 获取所有已存储数据的用户ID分批次处理
repo = ImplicitEmotionsStorageRepository(db)
# 先统计总数用于日志
from sqlalchemy import func
total_users = db.execute(
select(func.count()).select_from(ImplicitEmotionsStorage)
).scalar() or 0
logger.info(f"找到 {total_users} 个需要更新的用户")
# 遍历每个用户并更新数据分批次避免一次性加载所有ID
for end_user_id in repo.get_all_user_ids(batch_size=100):
logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time()
implicit_success = False
emotion_success = False
errors = []
try:
# 更新隐性记忆画像
try:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache(
end_user_id=end_user_id,
profile_data=profile_data,
db=db
)
implicit_success = True
logger.info(f"成功更新用户 {end_user_id} 的隐性记忆画像")
except Exception as e:
error_msg = f"隐性记忆更新失败: {str(e)}"
errors.append(error_msg)
logger.error(f"用户 {end_user_id} {error_msg}")
# 更新情绪建议
try:
emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id,
db=db,
language="zh"
)
await emotion_service.save_suggestions_cache(
end_user_id=end_user_id,
suggestions_data=suggestions_data,
db=db
)
emotion_success = True
logger.info(f"成功更新用户 {end_user_id} 的情绪建议")
except Exception as e:
error_msg = f"情绪建议更新失败: {str(e)}"
errors.append(error_msg)
logger.error(f"用户 {end_user_id} {error_msg}")
# 统计结果
if implicit_success:
successful_implicit += 1
if emotion_success:
successful_emotion += 1
if not implicit_success and not emotion_success:
failed += 1
user_elapsed = time.time() - user_start_time
# 记录用户处理结果
user_result = {
"end_user_id": end_user_id,
"implicit_success": implicit_success,
"emotion_success": emotion_success,
"errors": errors,
"elapsed_time": user_elapsed
}
user_results.append(user_result)
logger.info(
f"用户 {end_user_id} 处理完成: "
f"隐性记忆={'成功' if implicit_success else '失败'}, "
f"情绪建议={'成功' if emotion_success else '失败'}, "
f"耗时={user_elapsed:.2f}"
)
except Exception as e:
# 单个用户失败不影响其他用户(错误隔离)
failed += 1
user_elapsed = time.time() - user_start_time
error_info = {
"end_user_id": end_user_id,
"implicit_success": False,
"emotion_success": False,
"errors": [str(e)],
"elapsed_time": user_elapsed
}
user_results.append(error_info)
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
new_users_initialized = 0
new_users_failed = 0
logger.info("开始处理当天新增的增量用户初始化")
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
logger.info(f"开始初始化新用户: {end_user_id}")
user_start_time = time.time()
implicit_success = False
emotion_success = False
errors = []
try:
try:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache(
end_user_id=end_user_id,
profile_data=profile_data,
db=db
)
implicit_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
except Exception as e:
error_msg = f"隐性记忆初始化失败: {str(e)}"
errors.append(error_msg)
logger.error(f"新用户 {end_user_id} {error_msg}")
try:
emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id,
db=db,
language="zh"
)
await emotion_service.save_suggestions_cache(
end_user_id=end_user_id,
suggestions_data=suggestions_data,
db=db
)
emotion_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
except Exception as e:
error_msg = f"情绪建议初始化失败: {str(e)}"
errors.append(error_msg)
logger.error(f"新用户 {end_user_id} {error_msg}")
if implicit_success or emotion_success:
new_users_initialized += 1
else:
new_users_failed += 1
user_elapsed = time.time() - user_start_time
user_results.append({
"end_user_id": end_user_id,
"type": "init",
"implicit_success": implicit_success,
"emotion_success": emotion_success,
"errors": errors,
"elapsed_time": user_elapsed
})
except Exception as e:
new_users_failed += 1
user_elapsed = time.time() - user_start_time
user_results.append({
"end_user_id": end_user_id,
"type": "init",
"implicit_success": False,
"emotion_success": False,
"errors": [str(e)],
"elapsed_time": user_elapsed
})
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
logger.info(
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
)
# ---- 增量用户处理结束 ----
# 记录总体统计信息
logger.info(
f"隐性记忆和情绪数据更新定时任务完成: "
f"存量用户总数={total_users}, "
f"隐性记忆成功={successful_implicit}, "
f"情绪建议成功={successful_emotion}, "
f"存量失败={failed}, "
f"增量初始化成功={new_users_initialized}, "
f"增量初始化失败={new_users_failed}"
)
return {
"status": "SUCCESS",
"message": (
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
),
"total_users": total_users,
"successful_implicit": successful_implicit,
"successful_emotion": successful_emotion,
"failed": failed,
"new_users_initialized": new_users_initialized,
"new_users_failed": new_users_failed,
"user_results": user_results[:50] # 只保留前50个用户的详细结果
}
except Exception as e:
logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
"total_users": total_users,
"successful_implicit": successful_implicit,
"successful_emotion": successful_emotion,
"failed": failed,
"new_users_initialized": 0,
"new_users_failed": 0,
"user_results": user_results[:50]
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
return result
except Exception as e:
elapsed_time = time.time() - start_time
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": elapsed_time,
"task_id": self.request.id
}

View File

@@ -29,10 +29,10 @@ REDIS_DB=
REDIS_PASSWORD=password
#celery
BROKER_URL=
RESULT_BACKEND=
CELERY_BROKER=
CELERY_BACKEND=
# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
REDIS_DB_CELERY_BROKER=
REDIS_DB_CELERY_BACKEND=
# Memory Cache Regeneration Configuration
# Interval in hours for regenerating memory insight and user summary cache

View File

@@ -0,0 +1,43 @@
"""202603051440
Revision ID: 6a4641cf192b
Revises: b4af97639217
Create Date: 2026-03-05 14:41:03.371557
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '6a4641cf192b'
down_revision: Union[str, None] = 'b4af97639217'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('implicit_emotions_storage',
sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'),
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'),
sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'),
sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('end_user_id')
)
op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('idx_updated_at', table_name='implicit_emotions_storage')
op.drop_table('implicit_emotions_storage')
# ### end Alembic commands ###

View File

@@ -0,0 +1,63 @@
"""202603051033
Revision ID: b4af97639217
Revises: 4bf27c66ae63
Create Date: 2026-03-05 10:36:06.282227
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b4af97639217'
down_revision: Union[str, None] = '4bf27c66ae63'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Add columns as nullable first to avoid table locks
op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
# Update existing rows with default values
op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL")
op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL")
op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL")
# Now make columns NOT NULL
op.alter_column('model_api_keys', 'capability', nullable=False)
op.alter_column('model_api_keys', 'is_omni', nullable=False)
op.alter_column('model_bases', 'capability', nullable=False)
op.alter_column('model_bases', 'is_omni', nullable=False)
op.alter_column('model_configs', 'capability', nullable=False)
op.alter_column('model_configs', 'is_omni', nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('model_configs', 'is_omni')
op.drop_column('model_configs', 'capability')
op.drop_column('model_bases', 'is_omni')
op.drop_column('model_bases', 'capability')
op.drop_column('model_api_keys', 'is_omni')
op.drop_column('model_api_keys', 'capability')
# ### end Alembic commands ###

View File

@@ -163,9 +163,14 @@ export const getImplicitInterestAreas = (end_user_id: string) => {
export const getImplicitHabits = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/habits/${end_user_id}`)
}
// Implicit Memory - Generate user portrait
export const generateProfile = (end_user_id: string) => {
return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id })
}
// Implicit Memory - Check if data exists
export const implicitCheckData = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/check-data/${end_user_id}`)
}
// Short-term memory
export const getShortTerm = (end_user_id: string) => {
return request.get(`/memory/short/short_term`, { end_user_id })

View File

@@ -1,16 +1,21 @@
import { type FC, useRef, useState } from 'react'
import RecordRTC from 'recordrtc'
import { fileUpload } from '@/api/fileStorage'
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
import { request } from '@/utils/request'
interface AudioRecorderProps {
onRecordingComplete?: (file: { file_id: string; file_key: string; }, blob: Blob) => void
className?: string
onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void
className?: string;
action?: string;
requestConfig?: Record<string, any>;
}
const AudioRecorder: FC<AudioRecorderProps> = ({
onRecordingComplete,
className = '',
action = fileUploadUrlWithoutApiPrefix,
requestConfig = {}
}) => {
const [isRecording, setIsRecording] = useState(false)
const recorderRef = useRef<RecordRTC | null>(null)
@@ -33,11 +38,17 @@ const AudioRecorder: FC<AudioRecorderProps> = ({
if (recorderRef.current) {
recorderRef.current.stopRecording(() => {
const blob = recorderRef.current!.getBlob()
const url = recorderRef.current!.toURL()
const formData = new FormData()
formData.append('file', blob, `recording_${Date.now()}.webm`)
fileUpload(formData)
request
.uploadFile(action, formData, requestConfig)
.then(res => {
onRecordingComplete?.(res as { file_id: string; file_key: string; }, blob)
onRecordingComplete?.({
...(res as { file_id: string; file_key: string }),
type: blob.type,
url
}, blob)
recorderRef.current?.destroy()
recorderRef.current = null
})

View File

@@ -2,10 +2,11 @@
* @Author: ZhaoYing
* @Date: 2025-12-10 16:46:14
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-10 12:13:52
* @Last Modified time: 2026-03-04 18:42:49
*/
import { type FC, useEffect, useMemo } from 'react'
import { Flex, Input, Form } from 'antd'
import SendIcon from '@/assets/images/conversation/send.svg'
import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg'
import LoadingIcon from '@/assets/images/conversation/loading.svg'
@@ -80,9 +81,31 @@ const ChatInput: FC<ChatInputProps> = ({
</div>
)
}
if (file.type.includes('video')) {
return (
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
<video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
<div
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
onClick={() => handleDelete(file)}
></div>
</div>
)
}
if (file.type.includes('audio')) {
return (
<div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
<audio src={file.url} controls className="rb:w-45 rb:h-16" />
<div
className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
onClick={() => handleDelete(file)}
></div>
</div>
)
}
return (
<div key={file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
{(file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
></div>}
{(file.type.includes('pdf')) && <div

View File

@@ -453,9 +453,11 @@ export const en = {
prevStep: 'Previous Step',
exportSuccess: 'Export successful',
recommend: 'Recommend',
default: 'Default',
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
imageSquareRequired: 'Please upload a square image',
nameInvalid: 'Name cannot start or end with a space',
notAllSpaces: 'Cannot be all spaces',
},
model: {
searchPlaceholder: 'search model…',
@@ -603,7 +605,13 @@ export const en = {
ollama: "Ollama",
xinference: "Xinference",
gpustack: "Gpustack",
bedrock: "Bedrock"
bedrock: "Bedrock",
is_vision: 'Vision Support',
is_omni: 'Omni Support',
vision: 'Vision',
audio: 'Audio',
video: 'Video',
},
knowledgeBase: {
home: 'Home',
@@ -1684,6 +1692,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
uploadFile: 'Upload File',
fileType: 'File Type',
image: 'Image',
video: 'Video',
audio: 'Audio',
fileUrl: 'File URL',
addRemoteFile: 'Add Remote File',
variableConfig: 'Variable Configuration',
@@ -1962,6 +1972,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
marketConnected: '● Connected',
marketDisconnected: '○ Disconnected',
marketConnecting: 'Connecting to {{name}}...',
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
},
workflow: {
coreNode: 'Core Nodes',
@@ -2311,6 +2323,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
suggestions: 'Personalized Suggestions',
suggestionLoading: 'Your personalized suggestions are being generated',
item: 'item',
noData: 'Emotion suggestion data does not exist, please click the refresh button to initialize',
},
reflectionEngine: {
reflectionEngineConfig: 'Reflection Engine Configuration',
@@ -2557,7 +2570,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
context_details: 'Preference Details',
supporting_evidence: 'Preference Source',
specific_examples: 'Source',
wordEmpty: 'Click on a node in the left chart to view preference details'
wordEmpty: 'Click on a node in the left chart to view preference details',
noData: 'Portrait data does not exist, please click the refresh button to initialize',
},
shortTermDetail: {
title: 'Short-term memory is the "workbench" of the AI system, connecting instant conversations with long-term knowledge bases. Through real-time capture, deep retrieval, intelligent extraction and filtering transformation, temporary unstructured information is converted into valuable long-term knowledge.',
@@ -2615,6 +2629,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
updated_at: 'Updated At',
entityTypes: 'Entity Types',
classSearchPlaceholder: 'Search types',
addClass: 'Add Type',
class_name: 'Type Name',
class_description: 'Type Definition',

View File

@@ -1033,9 +1033,11 @@ export const zh = {
prevStep: '上一步',
exportSuccess: '导出成功',
recommend: '推荐',
default: '默认',
logoTip: `支持图片格式JPG、PNG\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
imageSquareRequired: '请上传正方形比例图片',
nameInvalid: '不能是空格开头或结尾',
notAllSpaces: '不能是纯空格',
},
model: {
searchPlaceholder: '搜索模型…',
@@ -1184,6 +1186,12 @@ export const zh = {
xinference: "Xinference",
gpustack: "Gpustack",
bedrock: "Bedrock",
is_vision: '支持视觉',
is_omni: '支持全模态',
vision: '视觉',
audio: '音频',
video: '视频',
},
timezones: {
'Asia/Shanghai': '中国标准时间 (UTC+8)',
@@ -1681,6 +1689,8 @@ export const zh = {
uploadFile: '上传文件',
fileType: '文件类型',
image: '图片',
video: '视频',
audio: '音频',
fileUrl: '文件链接',
addRemoteFile: '添加远程文件',
variableConfig: '变量配置',
@@ -1959,6 +1969,8 @@ export const zh = {
marketConnected: '● 已连接',
marketDisconnected: '○ 未连接',
marketConnecting: '正在连接 {{name}}...',
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
},
workflow: {
coreNode: '核心节点',
@@ -2312,6 +2324,7 @@ export const zh = {
suggestions: '个性化建议',
suggestionLoading: '您的个性化建议正在生成中',
item: '个',
noData: '情绪建议数据不存在,请点击刷新按钮进行初始化',
},
reflectionEngine: {
reflectionEngineConfig: '反思引擎配置',
@@ -2558,7 +2571,8 @@ export const zh = {
context_details: '偏好详情',
supporting_evidence: '偏好来源',
specific_examples: '来源',
wordEmpty: '点击左侧图表中的节点查看偏好详情'
wordEmpty: '点击左侧图表中的节点查看偏好详情',
noData: '画像数据不存在,请点击刷新按钮进行初始化',
},
shortTermDetail: {
title: '短期记忆是AI系统的"工作台",连接即时对话与长期知识库。通过实时捕获、深度检索、智能提取和筛选转化,将临时的非结构化信息转化为有价值的长期知识。',
@@ -2616,6 +2630,7 @@ export const zh = {
updated_at: '更新时间',
entityTypes: '实体类型',
classSearchPlaceholder: '搜索类型',
addClass: '添加类型',
class_name: '类型名称',
class_description: '类型定义',

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-02 16:35:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-02 16:35:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-06 10:39:00
*/
/**
* HTTP Request Utility Module
@@ -183,7 +183,9 @@ service.interceptors.response.use(
msg = msg || i18n.t('common.serverError');
break;
default:
if (!msg && Array.isArray(error.response?.data?.detail)) {
if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
msg = i18n.t(`common.${msg}`)
} else if (!msg && Array.isArray(error.response?.data?.detail)) {
msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
} else {
msg = msg || i18n.t('common.unknownError');

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-02 16:35:43
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-02 16:35:43
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-04 18:19:24
*/
/**
* Server-Sent Events (SSE) Stream Utility Module
@@ -176,17 +176,17 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
case 500:
case 502:
const errorData = await response.json();
errorData.error || i18n.t('common.serviceUpgrading');
message.warning(errorData.error || i18n.t('common.serviceUpgrading'));
return;
let errorInfo = errorData.error || i18n.t('common.serviceUpgrading')
message.warning(errorInfo);
throw errorInfo;
case 400:
const error = await response.json();
message.warning(error.error);
throw error || 'Bad Request';
throw error.error || 'Bad Request';
case 504:
const errorJson = await response.json();
message.warning(errorJson.error || i18n.t('common.serverError'));
return;
throw errorData.error;
case 401:
if (url?.includes('/public')) {
return message.warning(i18n.t('common.publicApiCannotRefreshToken'));

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:27:39
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-03 14:21:54
* @Last Modified time: 2026-03-05 17:03:46
*/
/**
* Chat debugging component for application testing
@@ -13,7 +13,7 @@
import { type FC, useEffect, useState, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import clsx from 'clsx'
import { Flex, Dropdown, type MenuProps, App } from 'antd'
import { Flex, Dropdown, type MenuProps, App, Divider } from 'antd'
import ChatIcon from '@/assets/images/application/chat.png'
import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png'
@@ -25,7 +25,7 @@ import type { ChatItem } from '@/components/Chat/types'
import { type SSEMessage } from '@/utils/stream'
import ChatInput from '@/components/Chat/ChatInput'
import UploadFiles from '@/views/Conversation/components/FileUpload'
// import AudioRecorder from '@/components/AudioRecorder'
import AudioRecorder from '@/components/AudioRecorder'
import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal'
import type { UploadFileListModalRef } from '@/views/Conversation/types'
import type { Variable } from './VariableList/types'
@@ -88,7 +88,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
content: '',
created_at: Date.now(),
};
if (isCluster) {
updateChatList(prev => prev.map(item => ({
...item,
@@ -134,7 +134,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
})
}
/** Update assistant message when error occurs */
const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => {
const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => {
if (message_length > 0 || !model_config_id) return
updateChatList(prev => {
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
.then(() => {
const message = msg
if (!message?.trim()) return
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
setLoading(false)
setCompareLoading(false)
return
}
addUserMessage(message, fileList)
setMessage(message)
@@ -198,27 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
};
setTimeout(() => {
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
if (chatVariables && chatVariables.length > 0) {
const needRequired: string[] = []
chatVariables.forEach(vo => {
params[vo.name] = vo.value
if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
isCanSend = false
needRequired.push(vo.name)
}
})
if (needRequired.length) {
messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
}
}
if (!isCanSend) {
return
}
runCompare(data.app_id, {
message,
files: fileList.map(file => {
@@ -243,7 +245,15 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
"stream": true,
"timeout": 60,
}, handleStreamMessage)
.finally(() => setLoading(false));
.catch(() => {
setLoading(false)
setCompareLoading(false)
updateClusterErrorAssistantMessage(0)
})
.finally(() => {
setLoading(false)
setCompareLoading(false)
})
}, 0)
})
.catch(() => {
@@ -288,7 +298,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
})
}
/** Update cluster message when error occurs */
const updateClusterErrorAssistantMessage = (message_length: number) => {
const updateClusterErrorAssistantMessage = (message_length: number) => {
if (message_length > 0) return
updateChatList(prev => {
@@ -331,7 +341,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
data.map(item => {
const { conversation_id, content, message_length } = item.data as { conversation_id: string, content: string, message_length: number };
switch(item.event) {
switch (item.event) {
case 'start':
if (conversation_id && conversationId !== conversation_id) {
setConversationId(conversation_id);
@@ -354,27 +364,35 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
};
setTimeout(() => {
draftRun(
data.app_id,
{
message,
conversation_id: conversationId,
stream: true,
files: fileList.map(file => {
if (file.url) {
return file
} else {
return {
type: file.type,
transfer_method: 'local_file',
upload_file_id: file.response.data.file_id
}
draftRun(
data.app_id,
{
message,
conversation_id: conversationId,
stream: true,
files: fileList.map(file => {
if (file.url) {
return file
} else {
return {
type: file.type,
transfer_method: 'local_file',
upload_file_id: file.response.data.file_id
}
}),
},
handleStreamMessage
)
.finally(() => setLoading(false))
}
}),
},
handleStreamMessage
)
.catch(() => {
setLoading(false)
setCompareLoading(false)
updateClusterErrorAssistantMessage(0)
})
.finally(() => {
setLoading(false)
setCompareLoading(false)
})
}, 0)
})
.catch(() => {
@@ -393,12 +411,17 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
const fileChange = (file?: any) => {
setFileList([...fileList, file])
}
// const handleRecordingComplete = async (file: any) => {
// console.log('file', file)
// }
const handleRecordingComplete = async (file: any) => {
setFileList([...fileList, {
uid: file.file_id,
response: { data: file },
thumbUrl: file.url,
type: file.type
}])
}
const handleShowUpload: MenuProps['onClick'] = ({ key }) => {
switch(key) {
switch (key) {
case 'define':
uploadFileListModalRef.current?.handleOpen()
break
@@ -415,99 +438,98 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
return (
<div className="rb:relative rb:h-full rb:flex rb:flex-col">
{chatList.length === 0
? <Empty
url={DebuggingEmpty}
? <Empty
url={DebuggingEmpty}
size={[300, 200]}
title={t('application.debuggingEmpty')}
subTitle={t('application.debuggingEmptyDesc')}
title={t('application.debuggingEmpty')}
subTitle={t('application.debuggingEmptyDesc')}
className="rb:h-full"
/>
: <>
<div className={clsx(`rb:relative rb:grid rb:grid-cols-${chatList.length} rb:overflow-hidden rb:w-full rb:flex-1 rb:min-h-0`)}>
{chatList.map((chat, index) => (
<div key={index} className={clsx('rb:flex rb:flex-col', {
"rb:border-r rb:border-[#DFE4ED]": index !== chatList.length - 1 && chatList.length > 1,
})}>
{chat.label &&
<div className={clsx(
"rb:grid rb:bg-[#F0F3F8] rb:text-center rb:flex-[0_0_auto]",
{
'rb:rounded-tr-xl': index === chatList.length - 1,
'rb:rounded-tl-xl': index === 0,
}
)}>
<div className='rb:relative rb:p-[10px_12px] rb:overflow-hidden'>
<div className="rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap rb:w-[calc(100%-24px)]">{chat.label}</div>
<div
className="rb:w-4 rb:h-4 rb:cursor-pointer rb:absolute rb:top-3 rb:right-3 rb:bg-cover rb:bg-[url('@/assets/images/close.svg')] rb:hover:bg-[url('@/assets/images/close_hover.svg')]"
onClick={() => handleDelete(index)}
></div>
: <>
<div className={clsx(`rb:relative rb:grid rb:grid-cols-${chatList.length} rb:overflow-hidden rb:w-full rb:flex-1 rb:min-h-0`)}>
{chatList.map((chat, index) => (
<div key={index} className={clsx('rb:flex rb:flex-col', {
"rb:border-r rb:border-[#DFE4ED]": index !== chatList.length - 1 && chatList.length > 1,
})}>
{chat.label &&
<div className={clsx(
"rb:grid rb:bg-[#F0F3F8] rb:text-center rb:flex-[0_0_auto]",
{
'rb:rounded-tr-xl': index === chatList.length - 1,
'rb:rounded-tl-xl': index === 0,
}
)}>
<div className='rb:relative rb:p-[10px_12px] rb:overflow-hidden'>
<div className="rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap rb:w-[calc(100%-24px)]">{chat.label}</div>
<div
className="rb:w-4 rb:h-4 rb:cursor-pointer rb:absolute rb:top-3 rb:right-3 rb:bg-cover rb:bg-[url('@/assets/images/close.svg')] rb:hover:bg-[url('@/assets/images/close_hover.svg')]"
onClick={() => handleDelete(index)}
></div>
</div>
</div>
</div>
}
<ChatContent
classNames={{
'rb:mx-[16px] rb:mt-6': true,
'rb:h-[calc(100vh-282px)]': isCluster,
'rb:h-[calc(100vh-380px)]': !isCluster,
}}
contentClassNames={{
'rb:max-w-[400px]!': chatList.length === 1,
'rb:max-w-[260px]!': chatList.length === 2,
'rb:max-w-[150px]!': chatList.length === 3,
'rb:max-w-[108px]!': chatList.length === 4,
}}
empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
data={chat.list || []}
streamLoading={compareLoading}
labelPosition="top"
labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label}
errorDesc={t('application.ReplyException')}
/>
</div>
))}
</div>
<div className="rb:relative rb:flex rb:items-center rb:gap-2.5 rb:m-4 rb:mb-1">
<ChatInput
message={message}
className="rb:relative!"
loading={loading}
fileChange={updateFileList}
fileList={fileList}
onSend={isCluster ? handleClusterSend : handleSend}
onChange={handleMessageChange}
>
<Flex justify="space-between" className="rb:flex-1">
<Flex gap={8} align="center">
<Dropdown
menu={{
items: [
{ key: 'define', label: t('memoryConversation.addRemoteFile') },
{
key: 'upload', label: (
<UploadFiles
fileType={['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg']}
onChange={fileChange}
/>
)
},
],
onClick: handleShowUpload
}
<ChatContent
classNames={{
'rb:mx-[16px] rb:mt-6': true,
'rb:h-[calc(100vh-282px)]': isCluster,
'rb:h-[calc(100vh-380px)]': !isCluster,
}}
>
<div
className="rb:size-6 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/link.svg')] rb:hover:bg-[url('@/assets/images/conversation/link_hover.svg')]"
></div>
</Dropdown>
contentClassNames={{
'rb:max-w-[400px]!': chatList.length === 1,
'rb:max-w-[260px]!': chatList.length === 2,
'rb:max-w-[150px]!': chatList.length === 3,
'rb:max-w-[108px]!': chatList.length === 4,
}}
empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
data={chat.list || []}
streamLoading={compareLoading}
labelPosition="top"
labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label}
errorDesc={t('application.ReplyException')}
/>
</div>
))}
</div>
<div className="rb:relative rb:flex rb:items-center rb:gap-2.5 rb:m-4 rb:mb-1">
<ChatInput
message={message}
className="rb:relative!"
loading={loading}
fileChange={updateFileList}
fileList={fileList}
onSend={isCluster ? handleClusterSend : handleSend}
onChange={handleMessageChange}
>
<Flex justify="space-between" className="rb:flex-1">
<Flex gap={8} align="center">
<Dropdown
menu={{
items: [
{ key: 'define', label: t('memoryConversation.addRemoteFile') },
{
key: 'upload', label: (
<UploadFiles
onChange={fileChange}
/>
)
},
],
onClick: handleShowUpload
}}
>
<div
className="rb:size-6 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/link.svg')] rb:hover:bg-[url('@/assets/images/conversation/link_hover.svg')]"
></div>
</Dropdown>
</Flex>
<Flex align="center">
<AudioRecorder onRecordingComplete={handleRecordingComplete} />
<Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
</Flex>
</Flex>
{/* <Flex align="center">
<AudioRecorder onRecordingComplete={handleRecordingComplete} />
<Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
</Flex> */}
</Flex>
</ChatInput>
</div>
</>
</ChatInput>
</div>
</>
}
<UploadFileListModal

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-06 21:09:42
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-11 11:32:48
* @Last Modified time: 2026-03-05 15:09:22
*/
/**
* File Upload Component
@@ -25,6 +25,7 @@ import { Upload, Progress, App } from 'antd';
import type { UploadProps, UploadFile } from 'antd';
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
import { useTranslation } from 'react-i18next';
import { request } from '@/utils/request'
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
@@ -56,27 +57,36 @@ interface UploadFilesProps extends Omit<UploadProps, 'onChange'> {
/** Custom file removal callback */
onRemove?: (file: UploadFile) => boolean | void | Promise<boolean | void>;
}
const transform_file_type = {
'text/plain': 'document/text',
'text/markdown': 'document/markdown',
'text/x-markdown': 'document/x-markdown',
'application/pdf': 'document/pdf',
'application/msword': 'document/doc',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
'application/vnd.ms-powerpoint': 'document/ppt',
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
}
// Mapping of file extensions to MIME types
const ALL_FILE_TYPE: {
[key: string]: string;
} = {
// txt: 'text/plain',
txt: 'text/plain',
md: 'text/markdown',
xmd: 'text/x-markdown',
pdf: 'application/pdf',
doc: 'application/msword',
docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
xls: 'application/vnd.ms-excel',
xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
csv: 'text/csv',
ppt: 'application/vnd.ms-powerpoint',
pptx: 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
// md: 'text/markdown',
// htm: 'text/html',
// html: 'text/html',
// json: 'application/json',
jpg: 'image/jpeg',
jpeg: 'image/jpeg',
png: 'image/png',
@@ -84,6 +94,23 @@ const ALL_FILE_TYPE: {
bmp: 'image/bmp',
webp: 'image/webp',
svg: 'image/svg+xml',
mp4: 'video/mp4',
mov: 'video/quicktime',
avi: 'video/x-msvideo',
mkv: 'video/x-matroska',
webm: 'video/webm',
flv: 'video/x-flv',
wmv: 'video/x-ms-wmv',
mp3: 'audio/mpeg',
wav: 'audio/wav',
ogg: 'audio/ogg',
aac: 'audio/aac',
flac: 'audio/flac',
m4a: 'audio/mp4',
wma: 'audio/x-ms-wma',
xm4a: 'audio/x-m4a',
}
export interface UploadFilesRef {
/** Current file list */
@@ -178,6 +205,10 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
* Handles upload state changes
*/
const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
newFileList.map(file => {
const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
file.type = type
})
setFileList(newFileList);
if (onChange) {
onChange(maxCount === 1 ? newFileList[newFileList.length - 1] : newFileList);

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-06 21:09:47
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-09 10:17:54
* @Last Modified time: 2026-03-04 17:47:09
*/
/**
* Upload File List Modal Component
@@ -104,7 +104,9 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
<Select
placeholder={t('memoryConversation.fileType')}
options={[
{ label: t('memoryConversation.image'), value: 'image' }
{ label: t('memoryConversation.image'), value: 'image' },
{ label: t('memoryConversation.audio'), value: 'audio' },
{ label: t('memoryConversation.video'), value: 'video' },
]}
className="rb:w-30"
/>

View File

@@ -14,7 +14,7 @@ import { type FC, useState, useEffect, useRef } from 'react'
import { useParams, useLocation } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import InfiniteScroll from 'react-infinite-scroll-component';
import { Flex, Skeleton, Form, Dropdown, type MenuProps, App } from 'antd'
import { Flex, Skeleton, Form, Dropdown, type MenuProps, App, Divider } from 'antd'
import { SettingOutlined } from '@ant-design/icons'
import clsx from 'clsx'
import dayjs from 'dayjs'
@@ -35,7 +35,7 @@ import OnlineCheckedIcon from '@/assets/images/conversation/onlineChecked.svg'
import MemoryFunctionCheckedIcon from '@/assets/images/conversation/memoryFunctionChecked.svg'
import { type SSEMessage } from '@/utils/stream'
import UploadFiles from './components/FileUpload'
// import AudioRecorder from '@/components/AudioRecorder'
import AudioRecorder from '@/components/AudioRecorder'
import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
import UploadFileListModal from './components/UploadFileListModal'
import type { VariableConfigModalRef } from '@/views/Workflow/types'
@@ -305,17 +305,27 @@ const Conversation: FC = () => {
}),
variables: params
}, handleStreamMessage, shareToken)
.catch(() => {
setLoading(false)
setStreamLoading(false)
})
.finally(() => {
setLoading(false)
setStreamLoading(false)
})
}
const fileChange = (file?: any) => {
form.setFieldValue('files', [...(queryValues.files || []), file])
}
// const handleRecordingComplete = async (file: any) => {
// console.log('file', file)
// }
const handleRecordingComplete = async (file: any) => {
form.setFieldValue('files', [...(queryValues.files || []), {
uid: file.file_id,
response: { data: file },
thumbUrl: file.url,
type: file.type
}])
}
const handleShowUpload: MenuProps['onClick'] = ({ key }) => {
switch(key) {
@@ -329,6 +339,7 @@ const Conversation: FC = () => {
form.setFieldValue('files', [...(queryValues.files || []), ...fileList])
}
const updateFileList = (fileList?: any[]) => {
console.log('fileList', fileList)
form.setFieldValue('files', [...(fileList || [])])
}
@@ -383,7 +394,7 @@ const Conversation: FC = () => {
<div className='rb:w-190 rb:h-screen rb:mx-auto rb:pt-10'>
<Chat
empty={<Empty url={ChatEmpty} className="rb:h-full" size={[320,180]} title={t('memoryConversation.chatEmpty')} subTitle={t('memoryConversation.emptyDesc')} />}
contentClassName="rb:h-[calc(100%-180px)]"
contentClassName={!queryValues?.files?.length ? "rb:h-[calc(100%-144px)]" : "rb:h-[calc(100%-208px)]"}
data={chatList}
streamLoading={streamLoading}
loading={loading}
@@ -405,13 +416,12 @@ const Conversation: FC = () => {
key: 'upload', label: (
<UploadFiles
action={shareFileUploadUrlWithoutApiPrefix}
fileType={['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg']}
onChange={fileChange}
requestConfig={{
headers: {
'Content-Type': 'multipart/form-data',
Authorization: `Bearer ${shareToken || ''}`,
} }}
}}}
/>
)
},
@@ -455,10 +465,19 @@ const Conversation: FC = () => {
</Form.Item>
)}
</Flex>
{/* <Flex align="center">
<AudioRecorder onRecordingComplete={handleRecordingComplete} />
<Flex align="center">
<AudioRecorder
action={shareFileUploadUrlWithoutApiPrefix}
requestConfig={{
headers: {
'Content-Type': 'multipart/form-data',
Authorization: `Bearer ${shareToken || ''}`,
}
}}
onRecordingComplete={handleRecordingComplete}
/>
<Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
</Flex> */}
</Flex>
</Flex>
</Form>
</Chat>

View File

@@ -15,6 +15,7 @@ import {
} from '@/api/knowledgeBase'
import RbModal from '@/components/RbModal'
import SliderInput from '@/components/SliderInput'
import { stringRegExp } from '@/utils/validator'
const { TextArea } = Input;
const { confirm } = Modal
@@ -519,12 +520,16 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
<Form.Item
name="name"
label={t('knowledgeBase.createForm.name')}
rules={[{ required: true, message: t('knowledgeBase.createForm.nameRequired') }]}
rules={[
{ required: true, message: t('knowledgeBase.createForm.nameRequired') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('knowledgeBase.createForm.name')} />
</Form.Item>
)}
<Form.Item name="description" label={t('knowledgeBase.createForm.description')}>
<Form.Item name="description" label={t('knowledgeBase.createForm.description')} rules={[{ max: 500 }]}>
<TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
</Form.Item>

View File

@@ -1,8 +1,8 @@
/*
* @Author: ZhaoYing
* @Date: 2026-02-03 17:33:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-03 17:33:15
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-05 16:28:58
*/
/**
* Memory Management Page
@@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => {
<List.Item key={item.config_id}>
<RbCard
title={item.config_name}
className="rb:relative"
>
{item.is_system_default &&
<div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
{t('common.default')}
</div>
}
<Tooltip title={item.config_desc}>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-4.25">{item.config_desc}</div>
</Tooltip>
<RbAlert className="rb:mt-3 ">
<div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

Some files were not shown because too many files have changed in this diff Show More