diff --git a/README.md b/README.md index 2f53a996..95d8d737 100644 --- a/README.md +++ b/README.md @@ -226,8 +226,8 @@ REDIS_PORT=6379 REDIS_DB=1 # Celery (Using Redis as broker) -BROKER_URL=redis://127.0.0.1:6379/0 -RESULT_BACKEND=redis://127.0.0.1:6379/0 +REDIS_DB_CELERY_BROKER=1 +REDIS_DB_CELERY_BACKEND=2 # JWT Secret Key (Formation method: openssl rand -hex 32) SECRET_KEY=your-secret-key-here diff --git a/README_CN.md b/README_CN.md index aed69b03..1472acac 100644 --- a/README_CN.md +++ b/README_CN.md @@ -201,8 +201,8 @@ REDIS_PORT=6379 REDIS_DB=1 # Celery (使用Redis作为broker) -BROKER_URL=redis://127.0.0.1:6379/0 -RESULT_BACKEND=redis://127.0.0.1:6379/0 +REDIS_DB_CELERY_BROKER=1 +REDIS_DB_CELERY_BACKEND=2 # JWT密钥 (生成方式: openssl rand -hex 32) SECRET_KEY=your-secret-key-here diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index 46d1c959..ca7aa91a 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -3,10 +3,8 @@ Cache 缓存模块 提供各种缓存功能的统一入口 """ -from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache +from .memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 0e21df0f..9a7fd225 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -3,12 +3,8 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 """ -from .emotion_memory import EmotionMemoryCache -from .implicit_memory import ImplicitMemoryCache from .interest_memory import InterestMemoryCache __all__ = [ - "EmotionMemoryCache", - "ImplicitMemoryCache", "InterestMemoryCache", ] diff --git a/api/app/cache/memory/emotion_memory.py b/api/app/cache/memory/emotion_memory.py deleted file mode 100644 index 45ea90de..00000000 --- a/api/app/cache/memory/emotion_memory.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Emotion Suggestions Cache - -情绪个性化建议缓存模块 -用于缓存用户的情绪个性化建议数据 -""" -import json -import logging -from 
typing import Optional, Dict, Any -from datetime import datetime - -from app.aioRedis import aio_redis - -logger = logging.getLogger(__name__) - - -class EmotionMemoryCache: - """情绪建议缓存类""" - - # Key 前缀 - PREFIX = "cache:memory:emotion_memory" - - @classmethod - def _get_key(cls, *parts: str) -> str: - """生成 Redis key - - Args: - *parts: key 的各个部分 - - Returns: - 完整的 Redis key - """ - return ":".join([cls.PREFIX] + list(parts)) - - @classmethod - async def set_emotion_suggestions( - cls, - user_id: str, - suggestions_data: Dict[str, Any], - expire: int = 86400 - ) -> bool: - """设置用户情绪建议缓存 - - Args: - user_id: 用户ID(end_user_id) - suggestions_data: 建议数据字典,包含: - - health_summary: 健康状态摘要 - - suggestions: 建议列表 - - generated_at: 生成时间(可选) - expire: 过期时间(秒),默认24小时(86400秒) - - Returns: - 是否设置成功 - """ - try: - key = cls._get_key("suggestions", user_id) - - # 添加生成时间戳 - if "generated_at" not in suggestions_data: - suggestions_data["generated_at"] = datetime.now().isoformat() - - # 添加缓存标记 - suggestions_data["cached"] = True - - value = json.dumps(suggestions_data, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) - logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒") - return True - except Exception as e: - logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True) - return False - - @classmethod - async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]: - """获取用户情绪建议缓存 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 建议数据字典,如果不存在或已过期返回 None - """ - try: - key = cls._get_key("suggestions", user_id) - value = await aio_redis.get(key) - - if value: - data = json.loads(value) - logger.info(f"成功获取情绪建议缓存: {key}") - return data - - logger.info(f"情绪建议缓存不存在或已过期: {key}") - return None - except Exception as e: - logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True) - return None - - @classmethod - async def delete_emotion_suggestions(cls, user_id: str) -> bool: - """删除用户情绪建议缓存 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 是否删除成功 - """ - try: - key = 
cls._get_key("suggestions", user_id) - result = await aio_redis.delete(key) - logger.info(f"删除情绪建议缓存: {key}, 结果: {result}") - return result > 0 - except Exception as e: - logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True) - return False - - @classmethod - async def get_suggestions_ttl(cls, user_id: str) -> int: - """获取情绪建议缓存的剩余过期时间 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 剩余秒数,-1表示永不过期,-2表示key不存在 - """ - try: - key = cls._get_key("suggestions", user_id) - ttl = await aio_redis.ttl(key) - logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒") - return ttl - except Exception as e: - logger.error(f"获取情绪建议缓存TTL失败: {e}") - return -2 diff --git a/api/app/cache/memory/implicit_memory.py b/api/app/cache/memory/implicit_memory.py deleted file mode 100644 index 21f08e9a..00000000 --- a/api/app/cache/memory/implicit_memory.py +++ /dev/null @@ -1,136 +0,0 @@ -""" -Implicit Memory Profile Cache - -隐式记忆用户画像缓存模块 -用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯) -""" -import json -import logging -from typing import Optional, Dict, Any -from datetime import datetime - -from app.aioRedis import aio_redis - -logger = logging.getLogger(__name__) - - -class ImplicitMemoryCache: - """隐式记忆用户画像缓存类""" - - # Key 前缀 - PREFIX = "cache:memory:implicit_memory" - - @classmethod - def _get_key(cls, *parts: str) -> str: - """生成 Redis key - - Args: - *parts: key 的各个部分 - - Returns: - 完整的 Redis key - """ - return ":".join([cls.PREFIX] + list(parts)) - - @classmethod - async def set_user_profile( - cls, - user_id: str, - profile_data: Dict[str, Any], - expire: int = 86400 - ) -> bool: - """设置用户完整画像缓存 - - Args: - user_id: 用户ID(end_user_id) - profile_data: 画像数据字典,包含: - - preferences: 偏好标签列表 - - portrait: 四维画像对象 - - interest_areas: 兴趣领域分布对象 - - habits: 行为习惯列表 - - generated_at: 生成时间(可选) - expire: 过期时间(秒),默认24小时(86400秒) - - Returns: - 是否设置成功 - """ - try: - key = cls._get_key("profile", user_id) - - # 添加生成时间戳 - if "generated_at" not in profile_data: - profile_data["generated_at"] = datetime.now().isoformat() - - # 添加缓存标记 - 
profile_data["cached"] = True - - value = json.dumps(profile_data, ensure_ascii=False) - await aio_redis.set(key, value, ex=expire) - logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒") - return True - except Exception as e: - logger.error(f"设置用户画像缓存失败: {e}", exc_info=True) - return False - - @classmethod - async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]: - """获取用户完整画像缓存 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 画像数据字典,如果不存在或已过期返回 None - """ - try: - key = cls._get_key("profile", user_id) - value = await aio_redis.get(key) - - if value: - data = json.loads(value) - logger.info(f"成功获取用户画像缓存: {key}") - return data - - logger.info(f"用户画像缓存不存在或已过期: {key}") - return None - except Exception as e: - logger.error(f"获取用户画像缓存失败: {e}", exc_info=True) - return None - - @classmethod - async def delete_user_profile(cls, user_id: str) -> bool: - """删除用户完整画像缓存 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 是否删除成功 - """ - try: - key = cls._get_key("profile", user_id) - result = await aio_redis.delete(key) - logger.info(f"删除用户画像缓存: {key}, 结果: {result}") - return result > 0 - except Exception as e: - logger.error(f"删除用户画像缓存失败: {e}", exc_info=True) - return False - - @classmethod - async def get_profile_ttl(cls, user_id: str) -> int: - """获取用户画像缓存的剩余过期时间 - - Args: - user_id: 用户ID(end_user_id) - - Returns: - 剩余秒数,-1表示永不过期,-2表示key不存在 - """ - try: - key = cls._get_key("profile", user_id) - ttl = await aio_redis.ttl(key) - logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒") - return ttl - except Exception as e: - logger.error(f"获取用户画像缓存TTL失败: {e}") - return -2 diff --git a/api/app/celery_app.py b/api/app/celery_app.py index c087e1d7..0319e079 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,26 +1,54 @@ import os import platform from datetime import timedelta -from celery.schedules import crontab from urllib.parse import quote from celery import Celery +from celery.schedules import crontab from app.core.config import settings +from 
app.core.logging_config import get_logger + +logger = get_logger(__name__) # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') # 创建 Celery 应用实例 -# broker: 任务队列(使用 Redis DB 0) -# backend: 结果存储(使用 Redis DB 10) +# broker: 任务队列(使用 Redis DB,由 REDIS_DB_CELERY_BROKER 指定) +# backend: 结果存储(使用 Redis DB,由 REDIS_DB_CELERY_BACKEND 指定) +# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, +# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md + +# Build canonical broker/backend URLs and force them into os.environ so that +# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first) +# cannot be overridden by stray env vars. +# See: https://github.com/celery/celery/issues/4284 +_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" +_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" +os.environ["CELERY_BROKER_URL"] = _broker_url +os.environ["CELERY_RESULT_BACKEND"] = _backend_url +# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click +# integration and accidentally override our canonical URLs. 
+os.environ.pop("BROKER_URL", None) +os.environ.pop("RESULT_BACKEND", None) +os.environ.pop("CELERY_BROKER", None) +os.environ.pop("CELERY_BACKEND", None) + celery_app = Celery( "redbear_tasks", - broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", - backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", + broker=_broker_url, + backend=_backend_url, ) +logger.info( + "Celery app initialized", + extra={ + "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"), + "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"), + }, +) # Default queue for unrouted tasks celery_app.conf.task_default_queue = 'memory_tasks' @@ -44,8 +72,8 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=1800, # 30分钟硬超时 - task_soft_time_limit=1500, # 25分钟软超时 + task_time_limit=3600, # 60分钟硬超时 + task_soft_time_limit=3000, # 50分钟软超时 # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution @@ -84,6 +112,7 @@ celery_app.conf.update( 'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'}, 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, + 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, }, ) @@ -95,6 +124,10 @@ memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute= memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS) forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS) +implicit_emotions_update_schedule = crontab( + hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR, + 
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE, +) #构建定时任务配置 beat_schedule_config = { @@ -120,6 +153,11 @@ beat_schedule_config = { "schedule": memory_increment_schedule, "args": (), }, + "update-implicit-emotions-storage": { + "task": "app.tasks.update_implicit_emotions_storage", + "schedule": implicit_emotions_update_schedule, + "args": (), + }, } celery_app.conf.beat_schedule = beat_schedule_config diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index e2849ad6..cdf94345 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -396,10 +396,10 @@ async def draft_run( from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService service = AppService(db) - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) # 1. 
验证应用 app = service._get_app_or_404(app_id) @@ -484,8 +484,8 @@ async def draft_run( } ) - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_cfg, model_config=model_config, @@ -789,8 +789,8 @@ async def draft_run_compare( # 流式返回 if payload.stream: async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) async for event in draft_service.run_compare_stream( agent_config=agent_cfg, models=model_configs, @@ -820,8 +820,8 @@ async def draft_run_compare( ) # 非流式返回 - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run_compare( agent_config=agent_cfg, models=model_configs, @@ -835,7 +835,8 @@ async def draft_run_compare( web_search=True, memory=True, parallel=payload.parallel, - timeout=payload.timeout or 60 + timeout=payload.timeout or 60, + files=payload.files ) logger.info( diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 620d8a1a..988aa706 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -441,14 +441,14 @@ async def retrieve_chunks( # 1 participle search, 2 semantic search, 3 hybrid search match retrieve_data.retrieve_type: case chunk_schema.RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, 
indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case chunk_schema.RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) + rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case _: - rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) + rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) # Efficient deduplication seen_ids = set() unique_rs = [] diff --git a/api/app/controllers/emotion_controller.py b/api/app/controllers/emotion_controller.py index eb2436d2..8cfc5014 100644 --- a/api/app/controllers/emotion_controller.py +++ b/api/app/controllers/emotion_controller.py @@ -208,14 +208,64 @@ async def get_emotion_health( +# @router.post("/check-data", response_model=ApiResponse) +# async def check_emotion_data_exists( +# request: EmotionSuggestionsRequest, +# db: Session = Depends(get_db), +# current_user: User = Depends(get_current_user), +# ): +# 
"""检查用户情绪建议数据是否存在 + +# Args: +# request: 包含 end_user_id +# db: 数据库会话 +# current_user: 当前用户 + +# Returns: +# 数据存在状态 +# """ +# try: +# api_logger.info( +# f"检查用户情绪建议数据是否存在: {request.end_user_id}", +# extra={"end_user_id": request.end_user_id} +# ) + +# # 从数据库获取建议 +# data = await emotion_service.get_cached_suggestions( +# end_user_id=request.end_user_id, +# db=db +# ) + +# if data is None: +# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在") +# return fail( +# BizCode.NOT_FOUND, +# "情绪建议数据不存在,请点击右上角刷新进行初始化", +# {"exists": False} +# ) + +# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在") +# return success(data={"exists": True}, msg="情绪建议数据已存在") + +# except Exception as e: +# api_logger.error( +# f"检查情绪建议数据失败: {str(e)}", +# extra={"end_user_id": request.end_user_id}, +# exc_info=True +# ) +# raise HTTPException( +# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, +# detail=f"检查情绪建议数据失败: {str(e)}" +# ) + + @router.post("/suggestions", response_model=ApiResponse) async def get_emotion_suggestions( request: EmotionSuggestionsRequest, - language_type: str = Header(default=None, alias="X-Language-Type"), db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): - """获取个性化情绪建议(从缓存读取) + """获取个性化情绪建议(从数据库读取) Args: request: 包含 end_user_id 和可选的 config_id @@ -223,77 +273,42 @@ async def get_emotion_suggestions( current_user: 当前用户 Returns: - 缓存的个性化情绪建议响应 + 存储的个性化情绪建议响应 """ try: - # 使用集中化的语言校验 - language = get_language_from_header(language_type) - api_logger.info( - f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)", + f"用户 {current_user.username} 请求获取个性化情绪建议", extra={ "end_user_id": request.end_user_id, "config_id": request.config_id } ) - # 从缓存获取建议 + # 从数据库获取建议 data = await emotion_service.get_cached_suggestions( end_user_id=request.end_user_id, db=db ) if data is None: - # 缓存不存在或已过期,自动触发生成 api_logger.info( - f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议", + f"用户 {request.end_user_id} 的建议数据不存在", extra={"end_user_id": request.end_user_id} ) - try: 
- data = await emotion_service.generate_emotion_suggestions( - end_user_id=request.end_user_id, - db=db, - language=language - ) - # 保存到缓存 - await emotion_service.save_suggestions_cache( - end_user_id=request.end_user_id, - suggestions_data=data, - db=db, - expires_hours=24 - ) - except (ValueError, KeyError) as gen_e: - # 预期内的业务异常:配置缺失、数据格式问题等 - api_logger.warning( - f"自动生成建议失败(业务异常): {str(gen_e)}", - extra={"end_user_id": request.end_user_id} - ) - return fail( - BizCode.NOT_FOUND, - f"自动生成建议失败: {str(gen_e)}", - "" - ) - except Exception as gen_e: - # 非预期异常:记录完整 traceback 便于排查 - api_logger.error( - f"自动生成建议时发生未预期异常: {str(gen_e)}", - extra={"end_user_id": request.end_user_id}, - exc_info=True - ) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"生成建议时发生内部错误: {str(gen_e)}" - ) + return success( + data={"exists": False}, + msg="情绪建议数据不存在,请点击右上角刷新进行初始化" + ) api_logger.info( - "个性化建议获取成功(缓存)", + "个性化建议获取成功", extra={ "end_user_id": request.end_user_id, "suggestions_count": len(data.get("suggestions", [])) } ) - return success(data=data, msg="个性化建议获取成功(缓存)") + return success(data=data, msg="个性化建议获取成功") except Exception as e: api_logger.error( @@ -314,7 +329,7 @@ async def generate_emotion_suggestions( db: Session = Depends(get_db), current_user: User = Depends(get_current_user), ): - """生成个性化情绪建议(调用LLM并缓存) + """生成个性化情绪建议(调用LLM并保存到数据库) Args: request: 包含 end_user_id @@ -342,12 +357,11 @@ async def generate_emotion_suggestions( language=language ) - # 保存到缓存 + # 保存到数据库 await emotion_service.save_suggestions_cache( end_user_id=request.end_user_id, suggestions_data=data, - db=db, - expires_hours=24 + db=db ) api_logger.info( @@ -369,4 +383,4 @@ async def generate_emotion_suggestions( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"生成个性化建议失败: {str(e)}" - ) + ) \ No newline at end of file diff --git a/api/app/controllers/implicit_memory_controller.py b/api/app/controllers/implicit_memory_controller.py index 
96e437d6..76a87c5f 100644 --- a/api/app/controllers/implicit_memory_controller.py +++ b/api/app/controllers/implicit_memory_controller.py @@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None: raise ValueError("confidence_threshold must be between 0.0 and 1.0") +@router.get("/check-data/{end_user_id}", response_model=ApiResponse) +@cur_workspace_access_guard() +async def check_user_data_exists( + end_user_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +) -> ApiResponse: + """ + 检查用户画像数据是否存在 + + Args: + end_user_id: 目标用户ID + + Returns: + 数据存在状态 + """ + api_logger.info(f"检查用户画像数据是否存在: {end_user_id}") + + try: + # Validate inputs + validate_user_id(end_user_id) + + # Create service with user-specific config + service = ImplicitMemoryService(db=db, end_user_id=end_user_id) + + # Get cached profile + cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) + + if cached_profile is None: + api_logger.info(f"用户 {end_user_id} 的画像数据不存在") + return success( + data={"exists": False}, + msg="画像数据不存在,请点击右上角刷新进行初始化" + ) + + api_logger.info(f"用户 {end_user_id} 的画像数据存在") + return success(data={"exists": True}, msg="画像数据已存在") + + except Exception as e: + return handle_implicit_memory_error(e, "检查画像数据", end_user_id) + + @router.get("/preferences/{end_user_id}", response_model=ApiResponse) @cur_workspace_access_guard() async def get_preference_tags( @@ -159,12 +201,8 @@ async def get_preference_tags( cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") - return fail( - BizCode.NOT_FOUND, - "画像缓存不存在或已过期,请右上角刷新生成新画像", - "" - ) + api_logger.info(f"用户 {end_user_id} 的画像数据不存在") + return fail(BizCode.NOT_FOUND, "", "") # Extract preferences from cache preferences = cached_profile.get("preferences", []) @@ -230,12 +268,8 @@ async def get_dimension_portrait( cached_profile = await 
service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") - return fail( - BizCode.NOT_FOUND, - "画像缓存不存在或已过期,请右上角刷新生成新画像", - "" - ) + api_logger.info(f"用户 {end_user_id} 的画像数据不存在") + return fail(BizCode.NOT_FOUND, "", "") # Extract portrait from cache portrait = cached_profile.get("portrait", {}) @@ -278,12 +312,8 @@ async def get_interest_area_distribution( cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") - return fail( - BizCode.NOT_FOUND, - "画像缓存不存在或已过期,请右上角刷新生成新画像", - "" - ) + api_logger.info(f"用户 {end_user_id} 的画像数据不存在") + return fail(BizCode.NOT_FOUND, "", "") # Extract interest areas from cache interest_areas = cached_profile.get("interest_areas", {}) @@ -330,12 +360,8 @@ async def get_behavior_habits( cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db) if cached_profile is None: - api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") - return fail( - BizCode.NOT_FOUND, - "画像缓存不存在或已过期,请右上角刷新生成新画像", - "" - ) + api_logger.info(f"用户 {end_user_id} 的画像数据不存在") + return fail(BizCode.NOT_FOUND, "", "") # Extract habits from cache habits = cached_profile.get("habits", []) diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 98012568..7f73663e 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -90,7 +90,7 @@ async def get_mcp_servers( cookies=cookies) raise_for_http_status(r) except requests.exceptions.RequestException as e: - api_logger.error(f"mFailed to get MCP servers: {str(e)}") + api_logger.error(f"Failed to get MCP servers: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get MCP servers: {str(e)}" @@ -118,6 +118,65 @@ async def get_mcp_servers( return 
success(data=result, msg="Query of mcp servers list successful") +@router.get("/operational_mcp_servers", response_model=ApiResponse) +async def get_operational_mcp_servers( + mcp_market_config_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Query the operational mcp servers list of the given mcp market config + - No paging or keyword search parameters are sent to the upstream API + - Return the upstream mcp server list only (total_count is not returned) + """ + api_logger.info( + f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}") + + # 1. Query mcp market config information from the database + api_logger.debug(f"Query mcp market config: {mcp_market_config_id}") + db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, + mcp_market_config_id=mcp_market_config_id, + current_user=current_user) + if not db_mcp_market_config: + api_logger.warning( + f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The mcp market config does not exist or access is denied" + ) + + # 2. 
Execute paged query + api = MCPApi() + token = db_mcp_market_config.token + api.login(token) + + url = f'{api.mcp_base_url}/operational' + headers = api.builder_headers(api.headers) + + try: + cookies = api.get_cookies(access_token=token, cookies_required=True) + r = api.session.get(url, headers=headers, cookies=cookies) + raise_for_http_status(r) + except requests.exceptions.RequestException as e: + api_logger.error(f"Failed to get operational MCP servers: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get operational MCP servers: {str(e)}" + ) + + data = api._handle_response(r) + total = data.get('total_count', 0) + mcp_server_list = data.get('mcp_server_list', []) + # items = [{ + # 'name': item.get('name', ''), + # 'id': item.get('id', ''), + # 'description': item.get('description', '') + # } for item in mcp_server_list] + + # 3. Return structured response + return success(data=mcp_server_list, msg="Query of operational mcp servers list successful") + + @router.get("/mcp_server", response_model=ApiResponse) async def get_mcp_server( mcp_market_config_id: uuid.UUID, diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index ccf93d68..e3d2bf92 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -1,28 +1,29 @@ from typing import List, Optional +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header +from sqlalchemy.orm import Session +from starlette.responses import StreamingResponse + from app.cache.memory.interest_memory import InterestMemoryCache from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger +from app.core.memory.agent.utils.redis_tool import store +from 
app.core.memory.agent.utils.session_tools import SessionService from app.core.rag.llm.cv_model import QWenCV from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User -from app.core.memory.agent.utils.session_tools import SessionService -from app.core.memory.agent.utils.redis_tool import store -from app.repositories import knowledge_repository, WorkspaceRepository +from app.repositories import knowledge_repository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse from app.services import task_service, workspace_service from app.services.memory_agent_service import MemoryAgentService from app.services.model_service import ModelConfigService -from dotenv import load_dotenv -from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header -from sqlalchemy.orm import Session -from starlette.responses import StreamingResponse load_dotenv() api_logger = get_api_logger() @@ -37,7 +38,7 @@ router = APIRouter( @router.get("/health/status", response_model=ApiResponse) async def get_health_status( - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user) ): """ Get latest health status written by Celery periodic task @@ -55,8 +56,9 @@ async def get_health_status( @router.get("/download_log") async def download_log( - log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), - current_user: User = Depends(get_current_user) + log_type: str = Query("file", regex="^(file|transmission)$", + description="日志类型: file=完整文件, transmission=实时流式传输"), + current_user: User = Depends(get_current_user) ): """ Download or stream agent service log file @@ -75,16 +77,16 @@ async def download_log( - transmission mode: StreamingResponse with SSE """ 
api_logger.info(f"Log download requested with log_type={log_type}") - + # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity) if log_type not in ["file", "transmission"]: api_logger.warning(f"Invalid log_type parameter: {log_type}") return fail( - BizCode.BAD_REQUEST, - "无效的log_type参数", + BizCode.BAD_REQUEST, + "无效的log_type参数", "log_type必须是'file'或'transmission'" ) - + # Route to appropriate mode if log_type == "file": # File mode: Return complete log file content @@ -119,10 +121,10 @@ async def download_log( @router.post("/writer_service", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Write service endpoint - processes write operations synchronously @@ -136,11 +138,11 @@ async def write_server( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - + # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, @@ -149,7 +151,7 @@ async def write_server( ) if storage_type is None: storage_type = 'neo4j' user_rag_memory_id = '' - + # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id if storage_type == 'rag': if workspace_id: @@ -161,13 +163,15 @@ async def write_server( if knowledge: user_rag_memory_id = str(knowledge.id) else: - api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") + api_logger.warning( + f"未找到名为 
'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") storage_type = 'neo4j' else: api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") storage_type = 'neo4j' - - api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") + + api_logger.info( + f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: messages_list = memory_agent_service.get_messages_list(user_input) result = await memory_agent_service.write_memory( @@ -175,7 +179,7 @@ async def write_server( messages_list, config_id, db, - storage_type, + storage_type, user_rag_memory_id, language ) @@ -195,10 +199,10 @@ async def write_server( @router.post("/writer_service_async", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Async write service endpoint - enqueues write processing to Celery @@ -213,10 +217,11 @@ async def write_server_async( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id - api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") + api_logger.info( + f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( @@ -244,7 +249,7 @@ async def write_server_async( args=[user_input.end_user_id, 
messages_list, config_id, storage_type, user_rag_memory_id, language] ) api_logger.info(f"Write task queued: {task.id}") - + return success(data={"task_id": task.id}, msg="写入任务已提交") except Exception as e: api_logger.error(f"Async write operation failed: {str(e)}") @@ -254,9 +259,9 @@ async def write_server_async( @router.post("/read_service", response_model=ApiResponse) @cur_workspace_access_guard() async def read_server( - user_input: UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Read service endpoint - processes read operations synchronously @@ -291,8 +296,9 @@ async def read_server( ) if knowledge: user_rag_memory_id = str(knowledge.id) - - api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + + api_logger.info( + f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") try: result = await memory_agent_service.read_memory( user_input.end_user_id, @@ -306,7 +312,8 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, + user_input.end_user_id) query = user_input.message # 调用 memory_agent_service 的方法生成最终答案 @@ -319,7 +326,7 @@ async def read_server( db=db ) if "信息不足,无法回答" in result['answer']: - result['answer']=retrieve_info + result['answer'] = retrieve_info return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup @@ -335,9 +342,10 @@ async def 
read_server( @router.post("/file", response_model=ApiResponse) async def file_update( files: List[UploadFile] = File(..., description="要上传的文件"), - model_id:str = Form(..., description="模型ID"), + model_id: str = Form(..., description="模型ID"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 文件上传接口 - 支持图片识别 @@ -350,9 +358,6 @@ async def file_update( Returns: 文件处理结果 """ - - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) api_logger.info(f"File upload requested, file count: {len(files)}") config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) apiConfig: ModelApiKey = config.api_keys[0] @@ -361,7 +366,7 @@ async def file_update( for file in files: api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}") content = await file.read() - + if file.content_type and file.content_type.startswith("image/"): vision_model = QWenCV( key=apiConfig.api_key, @@ -375,12 +380,12 @@ async def file_update( else: api_logger.warning(f"Unsupported file type: {file.content_type}") file_content.append(f"[不支持的文件类型: {file.content_type}]") - + result_text = ';'.join(file_content) api_logger.info(f"File processing completed, result length: {len(result_text)}") - + return success(data=result_text, msg="转换文本成功") - + except Exception as e: api_logger.error(f"File processing failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e)) @@ -430,8 +435,8 @@ async def read_server_async( @router.get("/read_result/", response_model=ApiResponse) async def get_read_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async read task @@ -452,7 +457,7 @@ async def get_read_task_result( try: result = task_service.get_task_memory_read_result(task_id) 
status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -470,7 +475,7 @@ async def get_read_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="查询任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -479,7 +484,7 @@ async def get_read_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -499,7 +504,7 @@ async def get_read_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -507,8 +512,8 @@ async def get_read_task_result( @router.get("/write_result/", response_model=ApiResponse) async def get_write_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async write task @@ -529,7 +534,7 @@ async def get_write_task_result( try: result = task_service.get_task_memory_write_result(task_id) status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -547,7 +552,7 @@ async def get_write_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="写入任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -556,7 +561,7 @@ async def get_write_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -576,7 +581,7 @@ async def get_write_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -584,9 +589,9 @@ async 
def get_write_task_result( @router.post("/status_type", response_model=ApiResponse) async def status_type( - user_input: Write_UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Determine the type of user message (read or write) @@ -629,9 +634,10 @@ async def status_type( @router.get("/stats/types", response_model=ApiResponse) async def get_knowledge_type_stats_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - only_active: bool = Query(True, description="仅统计有效记录(status=1)"), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + only_active: bool = Query(True, description="仅统计有效记录(status=1)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 @@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api( - 知识库类型根据当前用户的 current_workspace_id 过滤 - 如果用户没有当前工作空间,对应的统计返回 0 """ - api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") + api_logger.info( + f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") try: - from app.db import get_db - - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) - # 调用service层函数 result = await memory_agent_service.get_knowledge_type_stats( end_user_id=end_user_id, @@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api( current_workspace_id=current_user.current_workspace_id, db=db ) - + return success(data=result, msg="获取知识库类型统计成功") except Exception as e: api_logger.error(f"Knowledge type stats failed: {str(e)}") @@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api( @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) 
async def get_interest_distribution_by_user_api( - end_user_id: str = Query(..., description="用户ID(必填)"), - limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), - language_type: str = Header(default=None, alias="X-Language-Type"), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), + end_user_id: str = Query(..., description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 获取指定用户的兴趣分布标签 @@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api( @router.get("/analytics/user_profile", response_model=ApiResponse) async def get_user_profile_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取用户详情,包含: @@ -756,17 +757,17 @@ async def get_user_profile_api( # ): # """ # Get parsed API documentation (Public endpoint - no authentication required) - + # Args: # file_path: Optional path to API docs file. If None, uses default path. 
- + # Returns: # Parsed API documentation including title, meta info, and sections # """ # api_logger.info(f"API docs requested, file_path: {file_path or 'default'}") # try: # result = await memory_agent_service.get_api_docs(file_path) - + # if result.get("success"): # return success(msg=result["msg"], data=result["data"]) # else: @@ -782,9 +783,9 @@ async def get_user_profile_api( @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) async def get_end_user_connected_config( - end_user_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取终端用户关联的记忆配置 @@ -803,9 +804,9 @@ async def get_end_user_connected_config( from app.services.memory_agent_service import ( get_end_user_connected_config as get_config, ) - + api_logger.info(f"Getting connected config for end_user: {end_user_id}") - + try: result = get_config(end_user_id, db) return success(data=result, msg="获取终端用户关联配置成功") @@ -814,4 +815,4 @@ async def get_end_user_connected_config( return fail(BizCode.NOT_FOUND, str(e)) except Exception as e: api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) \ No newline at end of file + return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 475d184e..1b5b45fb 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -606,8 +606,8 @@ async def dashboard_data( # 获取RAG相关数据 try: - # total_memory: 使用 total_chunk(总chunk数) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) + # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数 + total_chunk = 
memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user) rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 826724c9..ee45fb83 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -2,7 +2,7 @@ from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from app.core.error_codes import BizCode @@ -85,6 +85,7 @@ def create_config( payload: ConfigParamsCreate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -99,7 +100,29 @@ def create_config( svc = DataConfigService(db) result = svc.create(payload) return success(data=result, msg="创建成功") + except ValueError as e: + err_str = str(e) + if err_str.startswith("DUPLICATE_CONFIG_NAME:"): + config_name = err_str.split(":", 1)[1] + api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. 
Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Create config failed: {err_str}") + return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) except Exception as e: + from sqlalchemy.exc import IntegrityError + if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')): + api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index bb1ba526..0de3d4fe 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -469,7 +469,9 @@ async def create_model_api_key_by_provider( config=api_key_data.config, is_active=api_key_data.is_active, priority=api_key_data.priority, - model_config_ids=model_config_ids + model_config_ids=model_config_ids, + capability=api_key_data.capability, + is_omni=api_key_data.is_omni ) created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index e4a87141..3d2a1bdb 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -25,7 +25,7 @@ from typing import Dict, Optional, List from urllib.parse import quote 
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from app.core.config import settings @@ -124,15 +124,23 @@ def _get_ontology_service( ) # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) - from app.repositories.model_repository import ModelApiKeyRepository - api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) - if not api_keys: + # from app.repositories.model_repository import ModelApiKeyRepository + from app.services.model_service import ModelApiKeyService + api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id) + if not api_key_config: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, detail="指定的LLM模型没有可用的API密钥" ) - api_key_config = api_keys[0] + # api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + # if not api_keys: + # logger.error(f"Model {llm_id} has no active API key") + # raise HTTPException( + # status_code=400, + # detail="指定的LLM模型没有可用的API密钥" + # ) + # api_key_config = api_keys[0] is_composite = getattr(model_config, 'is_composite', False) logger.info( @@ -154,6 +162,7 @@ def _get_ontology_service( provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, max_retries=3, timeout=60.0 ) @@ -280,7 +289,8 @@ async def extract_ontology( async def create_scene( request: SceneCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type") ): """创建本体场景 @@ -351,8 +361,18 @@ async def create_scene( return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) except RuntimeError as e: - api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True) 
- return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) + err_str = str(e) + if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str: + api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}") + from app.core.language_utils import get_language_from_header + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str) except Exception as e: api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True) @@ -514,10 +534,9 @@ async def delete_scene( f"尝试删除系统默认场景: user_id={current_user.id}, " f"scene_id={scene_id}, scene_name={scene.scene_name}" ) - return fail( - BizCode.BAD_REQUEST, - "系统默认场景不可删除", - "该场景为系统预设场景,不允许删除" + raise HTTPException( + status_code=400, + detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE" ) # 创建OntologyService实例 @@ -543,6 +562,9 @@ async def delete_scene( return success(data={"deleted": success_flag}, msg="场景删除成功") + except HTTPException: + raise + except ValueError as e: api_logger.warning(f"Validation error in scene deletion: {str(e)}") return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) @@ -650,7 +672,8 @@ async def get_scenes( async def create_class( request: ClassCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type") ): """创建本体类型 @@ -665,7 +688,7 @@ async def create_class( ApiResponse: 包含创建的类型信息 """ from 
app.controllers.ontology_secondary_routes import create_class_handler - return await create_class_handler(request, db, current_user) + return await create_class_handler(request, db, current_user, x_language_type) @router.put("/class/{class_id}", response_model=ApiResponse) diff --git a/api/app/controllers/ontology_secondary_routes.py b/api/app/controllers/ontology_secondary_routes.py index 607a0739..8720065b 100644 --- a/api/app/controllers/ontology_secondary_routes.py +++ b/api/app/controllers/ontology_secondary_routes.py @@ -7,7 +7,7 @@ from uuid import UUID from typing import Optional -from fastapi import Depends +from fastapi import Depends, Header from sqlalchemy.orm import Session from app.core.error_codes import BizCode @@ -58,7 +58,7 @@ async def scenes_handler( workspace_id: Optional[str] = None, scene_name: Optional[str] = None, page: Optional[int] = None, - page_size: Optional[int] = None, + pagesize: Optional[int] = None, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -71,14 +71,14 @@ async def scenes_handler( workspace_id: 工作空间ID(可选,默认当前用户工作空间) scene_name: 场景名称关键词(可选,支持模糊匹配) page: 页码(可选,从1开始,仅在全量查询时有效) - page_size: 每页数量(可选,仅在全量查询时有效) + pagesize: 每页数量(可选,仅在全量查询时有效) db: 数据库会话 current_user: 当前用户 """ operation = "search" if scene_name else "list" api_logger.info( f"Scene {operation} requested by user {current_user.id}, " - f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" + f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}" ) try: @@ -105,13 +105,13 @@ async def scenes_handler( api_logger.warning(f"Invalid page number: {page}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") - if page_size is not None and page_size < 1: - api_logger.warning(f"Invalid page_size: {page_size}") + if pagesize is not None and pagesize < 1: + api_logger.warning(f"Invalid pagesize: {pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") - # 
如果只提供了page或page_size中的一个,返回错误 - if (page is not None and page_size is None) or (page is None and page_size is not None): - api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + # 如果只提供了page或pagesize中的一个,返回错误 + if (page is not None and pagesize is None) or (page is None and pagesize is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") # 模糊搜索场景(支持分页) @@ -119,17 +119,15 @@ async def scenes_handler( total = len(scenes) # 如果提供了分页参数,进行分页处理 - if page is not None and page_size is not None: - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size + if page is not None and pagesize is not None: + start_idx = (page - 1) * pagesize + end_idx = start_idx + pagesize scenes = scenes[start_idx:end_idx] # 构建响应 items = [] for scene in scenes: - # 获取前3个class_name作为entity_type entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None - # 动态计算 type_num type_num = len(scene.classes) if scene.classes else 0 items.append(SceneResponse( @@ -141,17 +139,16 @@ async def scenes_handler( workspace_id=scene.workspace_id, created_at=scene.created_at, updated_at=scene.updated_at, - classes_count=type_num + classes_count=type_num, + is_system_default=scene.is_system_default )) # 构建响应(包含分页信息) - if page is not None and page_size is not None: - # 计算是否有下一页 - hasnext = (page * page_size) < total - + if page is not None and pagesize is not None: + hasnext = (page * pagesize) < total pagination_info = PaginationInfo( page=page, - pagesize=page_size, + pagesize=pagesize, total=total, hasnext=hasnext ) @@ -165,28 +162,25 @@ async def scenes_handler( ) else: # 获取所有场景(支持分页) - # 验证分页参数 if page is not None and page < 1: api_logger.warning(f"Invalid page number: {page}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") - if page_size is not None and page_size < 1: - api_logger.warning(f"Invalid page_size: 
{page_size}") + if pagesize is not None and pagesize < 1: + api_logger.warning(f"Invalid pagesize: {pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") - # 如果只提供了page或page_size中的一个,返回错误 - if (page is not None and page_size is None) or (page is None and page_size is not None): - api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + # 如果只提供了page或pagesize中的一个,返回错误 + if (page is not None and pagesize is None) or (page is None and pagesize is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") - scenes, total = service.list_scenes(ws_uuid, page, page_size) + scenes, total = service.list_scenes(ws_uuid, page, pagesize) # 构建响应 items = [] for scene in scenes: - # 获取前3个class_name作为entity_type entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None - # 动态计算 type_num type_num = len(scene.classes) if scene.classes else 0 items.append(SceneResponse( @@ -198,17 +192,16 @@ async def scenes_handler( workspace_id=scene.workspace_id, created_at=scene.created_at, updated_at=scene.updated_at, - classes_count=type_num + classes_count=type_num, + is_system_default=scene.is_system_default )) # 构建响应(包含分页信息) - if page is not None and page_size is not None: - # 计算是否有下一页 - hasnext = (page * page_size) < total - + if page is not None and pagesize is not None: + hasnext = (page * pagesize) < total pagination_info = PaginationInfo( page=page, - pagesize=page_size, + pagesize=pagesize, total=total, hasnext=hasnext ) @@ -238,7 +231,8 @@ async def scenes_handler( async def create_class_handler( request: ClassCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = None ): """创建本体类型(统一使用列表形式,支持单个或批量)""" @@ -271,8 +265,11 @@ async def create_class_handler( ] if count == 1: - # 单个创建 
+ # 单个创建 - 先检查重名 class_data = classes_data[0] + existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id) + if existing: + raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}") ontology_class = service.create_class( scene_id=request.scene_id, class_name=class_data["class_name"], @@ -330,12 +327,36 @@ async def create_class_handler( return success(data=response.model_dump(mode='json'), msg="批量创建完成") except ValueError as e: - api_logger.warning(f"Validation error in class creation: {str(e)}") - return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) - + err_str = str(e) + if err_str.startswith("DUPLICATE_CLASS_NAME:"): + class_name = err_str.split(":", 1)[1] + api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}") + from app.core.language_utils import get_language_from_header + from fastapi.responses import JSONResponse + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. 
Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.warning(f"Validation error in class creation: {err_str}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str) + except RuntimeError as e: - api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + err_str = str(e) + if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str: + api_logger.warning(f"Duplicate class name in scene {request.scene_id}") + from app.core.language_utils import get_language_from_header + from fastapi.responses import JSONResponse + lang = get_language_from_header(x_language_type) + class_name = request.classes[0].class_name if request.classes else "" + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. 
Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str) except Exception as e: api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) @@ -615,6 +636,7 @@ async def classes_handler( scene_id=scene_uuid, scene_name=scene.scene_name, scene_description=scene.scene_description, + is_system_default=scene.is_system_default, items=items ) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 61a919b1..64143f57 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -249,6 +249,7 @@ async def chat( app_id=app.id, workspace_id=workspace_id, release_id=app.current_release.id, + public=True ): event_type = event.get("event", "message") event_data = event.get("data", {}) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index accd749e..34489e8a 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. 
""" - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") memory_api_service = MemoryAPIService(db) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index fae20ea2..88b6371c 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,35 +11,37 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType +from app.models.models_model import ModelType, ModelProvider from app.services.memory_agent_service import ( get_end_user_connected_config, ) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool + logger = get_business_logger() class LangChainAgent: def __init__( - self, - model_name: str, - api_key: str, - provider: str = "openai", - api_base: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2000, - system_prompt: Optional[str] = None, - tools: Optional[Sequence[BaseTool]] = None, - streaming: bool = False, - max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) - max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 + self, + model_name: str, + api_key: str, + provider: str = "openai", + api_base: Optional[str] = None, + is_omni: bool = False, + temperature: float = 0.7, + max_tokens: int = 2000, + system_prompt: Optional[str] = None, + tools: Optional[Sequence[BaseTool]] = None, + 
streaming: bool = False, + max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) + max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 ): """初始化 LangChain Agent @@ -60,12 +62,13 @@ class LangChainAgent: self.provider = provider self.tools = tools or [] self.streaming = streaming + self.is_omni = is_omni self.max_tool_consecutive_calls = max_tool_consecutive_calls - + # 工具调用计数器:记录每个工具的连续调用次数 self.tool_call_counter: Dict[str, int] = {} self.last_tool_called: Optional[str] = None - + # 根据工具数量动态调整最大迭代次数 # 基础值 + 每个工具额外的调用机会 if max_iterations is None: @@ -73,9 +76,9 @@ class LangChainAgent: self.max_iterations = 5 + len(self.tools) * 2 else: self.max_iterations = max_iterations - + self.system_prompt = system_prompt or "你是一个专业的AI助手" - + logger.debug( f"Agent 迭代次数配置: max_iterations={self.max_iterations}, " f"tool_count={len(self.tools)}, " @@ -89,6 +92,7 @@ class LangChainAgent: provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni, extra_params={ "temperature": temperature, "max_tokens": max_tokens, @@ -143,21 +147,22 @@ class LangChainAgent: """ from langchain_core.tools import StructuredTool from functools import wraps - + wrapped_tools = [] - + for original_tool in tools: tool_name = original_tool.name original_func = original_tool.func if hasattr(original_tool, 'func') else None - + if not original_func: # 如果无法获取原始函数,直接使用原工具 wrapped_tools.append(original_tool) continue - + # 创建包装函数 def make_wrapped_func(tool_name, original_func): """创建包装函数的工厂函数,避免闭包问题""" + @wraps(original_func) def wrapped_func(*args, **kwargs): """包装后的工具函数,跟踪连续调用次数""" @@ -168,13 +173,13 @@ class LangChainAgent: # 切换到新工具,重置计数器 self.tool_call_counter[tool_name] = 1 self.last_tool_called = tool_name - + current_count = self.tool_call_counter[tool_name] - + logger.debug( f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}" ) - + # 检查是否超过最大连续调用次数 if current_count > self.max_tool_consecutive_calls: logger.warning( @@ -185,12 +190,12 @@ class 
LangChainAgent: f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次," f"未找到有效结果。请尝试其他方法或直接回答用户的问题。" ) - + # 调用原始工具函数 return original_func(*args, **kwargs) - + return wrapped_func - + # 使用 StructuredTool 创建新工具 wrapped_tool = StructuredTool( name=original_tool.name, @@ -198,17 +203,17 @@ class LangChainAgent: func=make_wrapped_func(tool_name, original_func), args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None ) - + wrapped_tools.append(wrapped_tool) - + return wrapped_tools def _prepare_messages( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - files: Optional[List[Dict[str, Any]]] = None + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + files: Optional[List[Dict[str, Any]]] = None ) -> List[BaseMessage]: """准备消息列表 @@ -248,7 +253,7 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -261,23 +266,26 @@ class LangChainAgent: List[Dict]: 消息内容列表 """ # 根据 provider 使用不同的文本格式 - if self.provider.lower() in ["bedrock", "anthropic"]: - # Anthropic/Bedrock: {"type": "text", "text": "..."} - content_parts = [{"type": "text", "text": text}] - else: - # 通义千问等: {"text": "..."} - content_parts = [{"text": text}] - + # if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE, + # ModelProvider.GPUSTACK] or ( + # self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)): + # # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."} + # content_parts = [{"type": "text", "text": text}] + # else: + # # 通义千问等: {"text": "..."} + # content_parts = [{"type": "text", "text": text}] + content_parts = [{"type": "text", "text": text}] + # 添加文件内容 # MultimodalService 已经根据 provider 返回了正确格式,直接使用 
content_parts.extend(files) - + logger.debug( f"构建多模态消息: provider={self.provider}, " f"parts={len(content_parts)}, " f"files={len(files)}" ) - + return content_parts async def chat( @@ -302,7 +310,7 @@ class LangChainAgent: Returns: Dict: 包含 content 和元数据的字典 """ - message_chat= message + message_chat = message start_time = time.time() actual_config_id = config_id # If config_id is None, try to get from end_user's connected config @@ -322,8 +330,8 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') + logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') + print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -367,14 +375,14 @@ class LangChainAgent: # 获取最后的 AI 消息 output_messages = result.get("messages", []) content = "" - + logger.debug(f"输出消息数量: {len(output_messages)}") total_tokens = 0 for msg in reversed(output_messages): if isinstance(msg, AIMessage): logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}") logger.debug(f"AI 消息内容: {msg.content}") - + # 处理多模态响应:content 可能是字符串或列表 if isinstance(msg.content, str): content = msg.content @@ -407,12 +415,13 @@ class LangChainAgent: response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break - + logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, 
message_chat, content, user_rag_memory_id, + actual_config_id) response = { "content": content, "model": self.model_name, @@ -439,16 +448,16 @@ class LangChainAgent: raise async def chat_stream( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - end_user_id:Optional[str] = None, - config_id: Optional[str] = None, - storage_type:Optional[str] = None, - user_rag_memory_id:Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + end_user_id: Optional[str] = None, + config_id: Optional[str] = None, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + memory_flag: Optional[bool] = True, + files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -482,7 +491,6 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) @@ -500,13 +508,13 @@ class LangChainAgent: full_content = '' try: async for event in self.agent.astream_events( - {"messages": messages}, - version="v2", - config={"recursion_limit": self.max_iterations} + {"messages": messages}, + version="v2", + config={"recursion_limit": self.max_iterations} ): chunk_count += 1 kind = event.get("event") - + # 处理所有可能的流式事件 if kind == "on_chat_model_stream": # LLM 流式输出 @@ -540,7 +548,7 @@ class LangChainAgent: full_content += item yield item yielded_content = True - + elif kind == "on_llm_stream": # 另一种 LLM 流式事件 chunk = event.get("data", {}).get("chunk") @@ -577,13 +585,13 @@ class LangChainAgent: full_content += chunk yield chunk yielded_content = True - + # 记录工具调用(可选) elif kind == "on_tool_start": logger.debug(f"工具调用开始: {event.get('name')}") elif kind == "on_tool_end": logger.debug(f"工具调用结束: {event.get('name')}") - + 
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 output_messages = event.get("data", {}).get("output", {}).get("messages", []) @@ -595,7 +603,8 @@ class LangChainAgent: yield total_tokens break if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, + actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise @@ -609,5 +618,3 @@ class LangChainAgent: logger.info("=" * 80) logger.info("chat_stream 方法执行结束") logger.info("=" * 80) - - diff --git a/api/app/core/config.py b/api/app/core/config.py index 62ff5c37..ba17da93 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -190,8 +190,10 @@ class Settings: LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB # Celery configuration (internal) - CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) - CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) + # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 + # 详见 docs/celery-env-bug-report.md + REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1")) + REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2")) # SMTP Email Configuration SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com") @@ -219,8 +221,12 @@ class Settings: FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter( Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")] ).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))) - + + IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2")) + # implicit_emotions_update: 每天几分执行(分钟,0-59) + IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal) + MEMORY_OUTPUT_DIR: str = 
os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index ac1fb9a6..c8cc0460 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -1,10 +1,10 @@ -import os import json +import os import time -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - structured = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 
开始处理问题分解,内容长度: {len(content)}") @@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - response_content = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 1880357c..06539ad1 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -6,31 +6,26 @@ import os # ===== 第三方库 ===== from langchain.agents import create_agent from langchain_openai import ChatOpenAI + from app.core.logging_config import get_agent_logger -from app.db import get_db, get_db_context - -from app.schemas import model_schema -from app.services.memory_config_service import MemoryConfigService -from app.services.model_service import ModelConfigService - -from app.core.memory.agent.services.search_service import SearchService -from app.core.memory.agent.utils.llm_tools import ( - COUNTState, - ReadState, - deduplicate_entries, - merge_to_key_value_pairs, -) from app.core.memory.agent.langgraph_graph.tools.tool import ( create_hybrid_retrieval_tool_sync, create_time_retrieval_tool, extract_tool_message_content, ) - +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) from app.core.rag.nlp.search import knowledge_retrieval +from app.db import 
get_db_context +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService logger = get_agent_logger(__name__) -db = next(get_db()) - async def rag_config(state): @@ -50,10 +45,12 @@ async def rag_config(state): "reranker_top_k": 10 } return kb_config -async def rag_knowledge(state,question): + + +async def rag_knowledge(state, question): kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') - user_rag_memory_id=state.get("user_rag_memory_id",'') + user_rag_memory_id = state.get("user_rag_memory_id", '') retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] @@ -61,13 +58,13 @@ async def rag_knowledge(state,question): cleaned_query = question raw_results = clean_content logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except Exception : - retrieval_knowledge=[] + except Exception: + retrieval_knowledge = [] clean_content = '' raw_results = '' cleaned_query = question logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - return retrieval_knowledge,clean_content,cleaned_query,raw_results + return retrieval_knowledge, clean_content, cleaned_query, raw_results async def llm_infomation(state: ReadState) -> ReadState: @@ -113,7 +110,7 @@ async def clean_databases(data) -> str: # 收集所有内容 content_list = [] - + # 处理重排序结果 reranked = results.get('reranked_results', {}) if reranked: @@ -141,7 +138,6 @@ async def clean_databases(data) -> str: elif isinstance(item, str): text_parts.append(item) - return '\n'.join(text_parts).strip() except Exception as e: @@ -150,23 +146,23 @@ async def clean_databases(data) -> str: async def retrieve_nodes(state: ReadState) -> ReadState: - ''' 模型信息 ''' - problem_extension=state.get('problem_extension', '')['context'] - storage_type=state.get('storage_type', '') - 
user_rag_memory_id=state.get('user_rag_memory_id', '') - end_user_id=state.get('end_user_id', '') + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - original=state.get('data', '') - problem_list=[] - for key,values in problem_extension.items(): + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): for data in values: problem_list.append(data) logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 async def process_question_nodes(idx, question): try: @@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: send_verify = [] for i, j in zip(keys, val, strict=False): - if j!=['']: + if j != ['']: send_verify.append({ "Query_small": i, "Answer_Small": j @@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState: } logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - return {'retrieve':dup_databases} - - + return {'retrieve': dup_databases} async def retrieve(state: ReadState) -> ReadState: # 从state中获取end_user_id import time - start=time.time() + start = time.time() problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') @@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState: with get_db_context() as db: # 使用同步数据库上下文管理器 config_service = MemoryConfigService(db) return await llm_infomation(state) + llm_config = await get_llm_info() api_key_obj = llm_config.api_keys[0] api_key = api_key_obj.api_key @@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState: ) time_retrieval_tool = create_time_retrieval_tool(end_user_id) - search_params 
= { "end_user_id": end_user_id, "return_raw_results": True } - hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + search_params = {"end_user_id": end_user_id, "return_raw_results": True} + hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, - tools=[time_retrieval_tool,hybrid_retrieval], + tools=[time_retrieval_tool, hybrid_retrieval], system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) @@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState: async with SEMAPHORE: # 限制并发 try: if storage_type == "rag" and user_rag_memory_id: - retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, + question) else: cleaned_query = question # 使用 asyncio 在线程池中运行同步的 agent.invoke @@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState: # json.dump(dup_databases, f, indent=4) logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") return {'retrieve': dup_databases} - - diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index cf832add..87606bf8 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,5 +1,3 @@ - - import os import time @@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService from app.core.rag.nlp.search import knowledge_retrieval - -from app.db import get_db +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) 
-db_session = next(get_db()) + class SummaryNodeService(LLMServiceMixin): """总结节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 summary_service = SummaryNodeService() + + async def rag_config(state): user_rag_memory_id = state.get('user_rag_memory_id', '') kb_config = { @@ -51,10 +51,12 @@ async def rag_config(state): "reranker_top_k": 10 } return kb_config -async def rag_knowledge(state,question): + + +async def rag_knowledge(state, question): kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') - user_rag_memory_id=state.get("user_rag_memory_id",'') + user_rag_memory_id = state.get("user_rag_memory_id", '') retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] @@ -62,25 +64,28 @@ async def rag_knowledge(state,question): cleaned_query = question raw_results = clean_content logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except Exception : - retrieval_knowledge=[] + except Exception: + retrieval_knowledge = [] clean_content = '' raw_results = '' cleaned_query = question logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - return retrieval_knowledge,clean_content,cleaned_query,raw_results + return retrieval_knowledge, clean_content, cleaned_query, raw_results + async def summary_history(state: ReadState) -> ReadState: end_user_id = state.get("end_user_id", '') history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history -async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, + search_mode) -> str: """ 增强的summary_llm函数,包含更好的错误处理和数据验证 """ data = state.get("data", '') - + # 构建系统提示词 if 
str(search_mode) == "0": system_prompt = await summary_service.template_service.render_template( @@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) try: # 使用优化的LLM服务进行结构化输出 - structured = await summary_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=response_model, - fallback_value=None - ) + with get_db_context() as db_session: + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) # 验证结构化响应 if structured is None: logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" - + # 根据操作类型提取答案 if operation_name == "summary": aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" @@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o else: logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" - + # 验证答案不为空 if not aimessages or aimessages.strip() == "": aimessages = "信息不足,无法回答" - + return aimessages - + except Exception as e: logger.error(f"结构化输出失败: {e}", exc_info=True) - + # 尝试非结构化输出作为fallback try: logger.info("尝试非结构化输出作为fallback") @@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o system_prompt=system_prompt, fallback_message="信息不足,无法回答" ) - + if response and response.strip(): # 简单清理响应 cleaned_response = response.strip() @@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if cleaned_response.startswith('```'): lines = cleaned_response.split('\n') cleaned_response = '\n'.join(lines[1:-1]) - + return cleaned_response else: return "信息不足,无法回答" - + except Exception as fallback_error: logger.error(f"Fallback也失败: {fallback_error}") return "信息不足,无法回答" -async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + +async def summary_redis_save(state: 
ReadState, aimessages) -> ReadState: data = state.get("data", '') end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( @@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState: ) await SessionService(store).cleanup_duplicates() logger.info(f"sessionid: {aimessages} 写入成功") -async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: - storage_type=state.get("storage_type",'') - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') + + +async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') input_summary = { "status": "success", "summary_result": aimessages, @@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: "user_rag_memory_id": user_rag_memory_id } } - retrieve={ + retrieve = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "_intermediate": { "type": "retrieval_summary", - "title":"快速检索", + "title": "快速检索", "summary": aimessages, "query": data, "storage_type": storage_type, @@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: } } - return input_summary,retrieve + return input_summary, retrieve + async def Input_Summary(state: ReadState) -> ReadState: - start=time.time() - storage_type=state.get("storage_type",'') + start = time.time() + storage_type = state.get("storage_type", '') memory_config = state.get('memory_config', None) - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') - end_user_id=state.get("end_user_id", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') + end_user_id = state.get("end_user_id", '') logger.info(f"Input_Summary: 
storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - history = await summary_history( state) + history = await summary_history(state) search_params = { "end_user_id": end_user_id, "question": data, @@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState: } try: - if storage_type!="rag": - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + if storage_type != "rag": + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, + memory_config=memory_config) else: retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: - logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True) retrieve_info, question, raw_results = "", data, [] try: # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', @@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState: summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary = summary_result[0] except Exception as e: - logger.error( f"Input_Summary failed: {e}", exc_info=True ) - summary= { + logger.error(f"Input_Summary failed: {e}", exc_info=True) + summary = { "status": "fail", "summary_result": "信息不足,无法回答", "storage_type": storage_type, @@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState: except Exception: duration = 0.0 log_time('检索', duration) - return {"summary":summary} + return {"summary": summary} -async def Retrieve_Summary(state: ReadState)-> ReadState: - retrieve=state.get("retrieve", '') - history = await summary_history( state) + +async def Retrieve_Summary(state: ReadState) -> ReadState: + retrieve = state.get("retrieve", '') + history = await summary_history(state) import 
json - with open("检索.json","w",encoding='utf-8') as f: + with open("检索.json", "w", encoding='utf-8') as f: f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) - retrieve=retrieve.get("Expansion_issue", []) - start=time.time() - retrieve_info_str=[] + retrieve = retrieve.get("Expansion_issue", []) + start = time.time() + retrieve_info_str = [] for data in retrieve: - if data=='': - retrieve_info_str='' + if data == '': + retrieve_info_str = '' else: for key, value in data.items(): - if key=='Answer_Small': + if key == 'Answer_Small': for i in value: retrieve_info_str.append(i) - retrieve_info_str=list(set(retrieve_info_str)) - retrieve_info_str='\n'.join(retrieve_info_str) + retrieve_info_str = list(set(retrieve_info_str)) + retrieve_info_str = '\n'.join(retrieve_info_str) - aimessages=await summary_llm(state,history,retrieve_info_str, - 'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + aimessages = await summary_llm(state, history, retrieve_info_str, + 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState: except Exception: duration = 0.0 log_time('Retrieval summary', duration) - + # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} + return {"summary": summary} -async def Summary(state: ReadState)-> ReadState: - start=time.time() +async def Summary(state: ReadState) -> ReadState: + start = time.time() query = state.get("data", '') - verify=state.get("verify", '') - verify_expansion_issue=verify.get("verified_data", '') - retrieve_info_str='' + verify = state.get("verify", '') + verify_expansion_issue = verify.get("verified_data", '') + retrieve_info_str = '' for data in 
verify_expansion_issue: for key, value in data.items(): - if key=='answer_small': + if key == 'answer_small': for i in value: - retrieve_info_str+=i+'\n' - history=await summary_history(state) + retrieve_info_str += i + '\n' + history = await summary_history(state) data = { "query": query, "history": history, "retrieve_info": retrieve_info_str } - aimessages=await summary_llm(state,history,data, - 'summary_prompt.jinja2','summary',SummaryResponse,0) + aimessages = await summary_llm(state, history, data, + 'summary_prompt.jinja2', 'summary', SummaryResponse, 0) if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) @@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState: # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} -async def Summary_fails(state: ReadState)-> ReadState: - storage_type=state.get("storage_type", '') - user_rag_memory_id=state.get("user_rag_memory_id", '') + return {"summary": summary} + + +async def Summary_fails(state: ReadState) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') history = await summary_history(state) query = state.get("data", '') verify = state.get("verify", '') @@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState: "history": history, "retrieve_info": retrieve_info_str } - aimessages = await summary_llm(state, history, data, - 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) - result= { + aimessages = await summary_llm(state, history, data, + 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) + result = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id } - return {"summary":result} \ No newline at end of file + return {"summary": result} diff --git 
a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index b809faf2..3f7b491e 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -1,8 +1,9 @@ +import asyncio import os -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) + class VerificationNodeService(LLMServiceMixin): """验证节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 verification_service = VerificationNodeService() + async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): """处理验证结果并生成输出格式""" storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') - + # 将 VerificationItem 对象转换为字典列表 verified_data = [] if messages_deal.expansion_issue: @@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): verified_data.append(item.model_dump()) elif isinstance(item, dict): 
verified_data.append(item) - + Verify_result = { "status": messages_deal.split_result, "verified_data": verified_data, @@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): } } return Verify_result + + async def Verify(state: ReadState): logger.info("=== Verify 节点开始执行 ===") try: content = state.get('data', '') end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") retrieve = state.get("retrieve", {}) - logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") - + logger.info( + f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") + retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") - + messages = { "Query": content, "Expansion_issue": retrieve_expansion } logger.info("Verify: 开始渲染模板") - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = VerificationResult.model_json_schema() - + system_prompt = await verification_service.template_service.render_template( template_name='split_verify_prompt.jinja2', operation_name='split_verify_prompt', @@ -94,29 +100,30 @@ async def Verify(state: ReadState): json_schema=json_schema ) logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}") - + # 使用优化的LLM服务,添加超时保护 logger.info("Verify: 开始调用 LLM") try: # 添加 asyncio.wait_for 超时包裹,防止无限等待 # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) - import asyncio - structured = await asyncio.wait_for( - verification_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - 
response_model=VerificationResult, - fallback_value={ - "query": content, - "history": history if isinstance(history, list) else [], - "expansion_issue": [], - "split_result": "failed", - "reason": "验证失败或超时" - } - ), - timeout=150.0 # 150秒超时 - ) + + with get_db_context() as db_session: + structured = await asyncio.wait_for( + verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "query": content, + "history": history if isinstance(history, list) else [], + "expansion_issue": [], + "split_result": "failed", + "reason": "验证失败或超时" + } + ), + timeout=150.0 # 150秒超时 + ) logger.info(f"Verify: LLM 调用完成,result={structured}") except asyncio.TimeoutError: logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值") @@ -127,11 +134,11 @@ async def Verify(state: ReadState): split_result="failed", reason="LLM调用超时" ) - + result = await Verify_prompt(state, structured) logger.info("=== Verify 节点执行完成 ===") return {"verify": result} - + except Exception as e: logger.error(f"Verify 节点执行失败: {e}", exc_info=True) # 返回失败的验证结果 @@ -152,4 +159,4 @@ async def Verify(state: ReadState): "user_rag_memory_id": state.get('user_rag_memory_id', '') } } - } \ No newline at end of file + } diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 3476d0ec..cba1b230 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage from langgraph.constants import START, END from langgraph.graph import StateGraph - from app.db import get_db from app.services.memory_config_service import MemoryConfigService @@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import ( ) - @asynccontextmanager async def make_read_graph(): """创建并返回 LangGraph 工作流""" @@ -49,7 +47,7 @@ async 
def make_read_graph(): workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) workflow.add_node("Summary_fails", Summary_fails) - + # 添加边 workflow.add_edge(START, "content_input") workflow.add_conditional_edges("content_input", Split_continue) @@ -62,20 +60,20 @@ async def make_read_graph(): workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) - '''-----''' # workflow.add_edge("Retrieve", END) - + # 编译工作流 graph = workflow.compile() yield graph - + except Exception as e: print(f"创建工作流失败: {e}") raise finally: print("工作流创建完成") + async def main(): """主函数 - 运行工作流""" message = "昨天有什么好看的电影" @@ -92,17 +90,19 @@ async def main(): service_name="MemoryAgentService" ) import time - start=time.time() + start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id - ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "end_user_id": end_user_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} # 获取节点更新信息 _intermediate_outputs = [] summary = '' - + async for update_event in graph.astream( initial_state, stream_mode="updates", @@ -110,7 +110,7 @@ async def main(): ): for node_name, node_data in update_event.items(): print(f"处理节点: {node_name}") - + # 处理不同Summary节点的返回结构 if 'Summary' in node_name: if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: @@ -125,23 +125,22 @@ async def main(): spit_data = node_data.get('spit_data', {}).get('_intermediate', None) if spit_data and spit_data != [] and spit_data != {}: _intermediate_outputs.append(spit_data) - + # Problem_Extension 节点 problem_extension = 
node_data.get('problem_extension', {}).get('_intermediate', None) if problem_extension and problem_extension != [] and problem_extension != {}: _intermediate_outputs.append(problem_extension) - + # Retrieve 节点 retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) if retrieve_node and retrieve_node != [] and retrieve_node != {}: _intermediate_outputs.extend(retrieve_node) - + # Verify 节点 verify_n = node_data.get('verify', {}).get('_intermediate', None) if verify_n and verify_n != [] and verify_n != {}: _intermediate_outputs.append(verify_n) - # Summary 节点 summary_n = node_data.get('summary', {}).get('_intermediate', None) if summary_n and summary_n != [] and summary_n != {}: @@ -161,17 +160,20 @@ async def main(): # print(f"=== 最终摘要 ===") print(summary) - + except Exception as e: import traceback traceback.print_exc() + finally: + db_session.close() - end=time.time() - print(100*'y') - print(f"总耗时: {end-start}s") - print(100*'y') + end = time.time() + print(100 * 'y') + print(f"总耗时: {end - start}s") + print(100 * 'y') if __name__ == "__main__": import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py deleted file mode 100644 index fddd54f6..00000000 --- a/api/app/core/memory/agent/utils/llm_client_pool.py +++ /dev/null @@ -1,56 +0,0 @@ - -import asyncio -from typing import Dict, Optional -from app.core.memory.utils.llm.llm_utils import get_llm_client_fast -from app.db import get_db -from app.core.logging_config import get_agent_logger - -logger = get_agent_logger(__name__) - -class LLMClientPool: - """LLM客户端连接池""" - - def __init__(self, max_size: int = 5): - self.max_size = max_size - self.pools: Dict[str, asyncio.Queue] = {} - self.active_clients: Dict[str, int] = {} - - async def get_client(self, llm_model_id: str): - """获取LLM客户端""" - if llm_model_id not in self.pools: - self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) - 
self.active_clients[llm_model_id] = 0 - - pool = self.pools[llm_model_id] - - try: - # 尝试从池中获取客户端 - client = pool.get_nowait() - logger.debug(f"从池中获取LLM客户端: {llm_model_id}") - return client - except asyncio.QueueEmpty: - # 池为空,创建新客户端 - if self.active_clients[llm_model_id] < self.max_size: - db_session = next(get_db()) - client = get_llm_client_fast(llm_model_id, db_session) - self.active_clients[llm_model_id] += 1 - logger.debug(f"创建新LLM客户端: {llm_model_id}") - return client - else: - # 等待可用客户端 - logger.debug(f"等待LLM客户端可用: {llm_model_id}") - return await pool.get() - - async def return_client(self, llm_model_id: str, client): - """归还LLM客户端到池中""" - if llm_model_id in self.pools: - try: - self.pools[llm_model_id].put_nowait(client) - logger.debug(f"归还LLM客户端到池: {llm_model_id}") - except asyncio.QueueFull: - # 池已满,丢弃客户端 - self.active_clients[llm_model_id] -= 1 - logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") - -# 全局客户端池 -llm_client_pool = LLMClientPool() diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index f5f49af0..dba6717d 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -21,31 +21,55 @@ from pydantic import BaseModel, Field T = TypeVar("T") + class RedBearModelConfig(BaseModel): """模型配置基类""" model_name: str provider: str api_key: str base_url: Optional[str] = None + is_omni: bool = False # 是否为 Omni 模型 # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2"))) - concurrency: int = 5 # 并发限流 + concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + class RedBearModelFactory: """模型工厂类""" - + @classmethod def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() - + # 打印供应商信息用于调试 from app.core.logging_config import 
get_business_logger logger = get_business_logger() - logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}") + logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}") + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + import httpx + if not config.base_url: + config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + timeout_config = httpx.Timeout( + timeout=config.timeout, + connect=60.0, + read=config.timeout, + write=60.0, + pool=10.0, + ) + return { + "model": config.model_name, + "base_url": config.base_url, + "api_key": config.api_key, + "timeout": timeout_config, + "max_retries": config.max_retries, + **config.extra_params + } if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -65,7 +89,7 @@ class RedBearModelFactory: "timeout": timeout_config, "max_retries": config.max_retries, **config.extra_params - } + } elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 @@ -82,7 +106,7 @@ class RedBearModelFactory: # region 从 base_url 或 extra_params 获取 from botocore.config import Config as BotoConfig from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id - + max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50")) max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2")) # Configure with increased connection pool @@ -90,16 +114,16 @@ class RedBearModelFactory: max_pool_connections=max_pool_connections, retries={'max_attempts': max_retries, 'mode': 'adaptive'} ) - + # 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID) model_id = normalize_bedrock_model_id(config.model_name) - + params = { "model_id": model_id, "config": boto_config, **config.extra_params } - + # 解析 API key (格式: access_key_id:secret_access_key) if config.api_key and ":" in config.api_key: 
access_key_id, secret_access_key = config.api_key.split(":", 1) @@ -107,45 +131,52 @@ class RedBearModelFactory: params["aws_secret_access_key"] = secret_access_key elif config.api_key: params["aws_access_key_id"] = config.api_key - + # 设置 region if config.base_url: params["region_name"] = config.base_url elif "region_name" not in params: params["region_name"] = "us-east-1" # 默认区域 - + return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - + @classmethod def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - return { + return { "model": config.model_name, # "base_url": config.base_url, "jina_api_key": config.api_key, **config.extra_params - } + } else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) -def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: + +def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]: """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + from langchain_openai import ChatOpenAI + return ChatOpenAI + + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if type == ModelType.LLM: from langchain_openai import OpenAI - return OpenAI + return OpenAI elif type == ModelType.CHAT: from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: from langchain_community.chat_models import ChatTongyi return ChatTongyi - elif provider == ModelProvider.OLLAMA: + elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaLLM return OllamaLLM 
elif provider == ModelProvider.BEDROCK: @@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType. else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_embedding_class(provider: str) -> type[Embeddings]: """根据模型提供商获取对应的模型类""" provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings + return OpenAIEmbeddings elif provider == ModelProvider.DASHSCOPE: from langchain_community.embeddings import DashScopeEmbeddings - return DashScopeEmbeddings + return DashScopeEmbeddings elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings @@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]: else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_rerank_class(provider: str): """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + provider = provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_community.document_compressors import JinaRerank - return JinaRerank - # elif provider == ModelProvider.OLLAMA: + return JinaRerank + # elif provider == ModelProvider.OLLAMA: # from langchain_ollama import OllamaEmbeddings # return OllamaEmbeddings else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) \ No newline at end of file + raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index e5b91d1c..2c0ab757 100644 --- 
a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -6,6 +6,8 @@ models: description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: bedrock @@ -15,6 +17,9 @@ models: description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -28,6 +33,9 @@ models: description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -42,6 +50,8 @@ models: description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -54,6 +64,9 @@ models: description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -67,6 +80,8 @@ models: description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -78,6 +93,8 @@ models: description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -89,6 +106,8 @@ models: description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -101,6 +120,8 @@ models: description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -113,6 +134,8 @@ models: description: amazon.rerank-v1:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + 
is_omni: false tags: - 重排序模型 logo: bedrock @@ -122,6 +145,8 @@ models: description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: bedrock @@ -131,6 +156,9 @@ models: description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 文本嵌入模型 - vision @@ -141,6 +169,8 @@ models: description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -150,6 +180,8 @@ models: description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -159,6 +191,8 @@ models: description: Cohere Embed 3 English文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -168,6 +202,8 @@ models: description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 - logo: bedrock + logo: bedrock \ No newline at end of file diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index af1c3619..89a16966 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -6,6 +6,8 @@ models: description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -16,6 +18,8 @@ models: description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -26,6 +30,8 @@ models: description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式 is_deprecated: false is_official: true + 
capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -36,6 +42,8 @@ models: description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -46,6 +54,8 @@ models: description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -56,6 +66,8 @@ models: description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -66,6 +78,8 @@ models: description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -76,6 +90,8 @@ models: description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -88,6 +104,8 @@ models: description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +118,9 @@ models: description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -112,6 +133,9 @@ models: description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -124,6 +148,8 @@ models: description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -135,6 +161,8 @@ models: description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + 
capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -147,6 +175,8 @@ models: description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -159,6 +189,8 @@ models: description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -171,6 +203,8 @@ models: description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -182,6 +216,8 @@ models: description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -193,6 +229,8 @@ models: description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -205,6 +243,8 @@ models: description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -217,6 +257,8 @@ models: description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -229,6 +271,8 @@ models: description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -241,6 +285,8 @@ models: description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -253,6 +299,8 @@ models: description: 
qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -265,6 +313,8 @@ models: description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -277,6 +327,8 @@ models: description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -289,6 +341,10 @@ models: description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -302,6 +358,10 @@ models: description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -315,6 +375,10 @@ models: description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -328,6 +392,10 @@ models: description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -341,6 +409,10 @@ models: description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -354,6 +426,10 @@ models: description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -367,6 +443,8 @@ models: description: 
qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -379,6 +457,8 @@ models: description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -391,6 +471,8 @@ models: description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -403,6 +485,8 @@ models: description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -415,6 +499,8 @@ models: description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -427,6 +513,8 @@ models: description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -439,6 +527,8 @@ models: description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -451,6 +541,8 @@ models: description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -463,6 +555,8 @@ models: description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -475,6 +569,8 @@ models: description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false 
tags: - 大语言模型 - multi-tool-call @@ -487,6 +583,8 @@ models: description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -498,6 +596,8 @@ models: description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -509,6 +609,8 @@ models: description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -520,6 +622,8 @@ models: description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -531,6 +635,8 @@ models: description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -544,6 +650,8 @@ models: description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -557,6 +665,8 @@ models: description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -569,6 +679,8 @@ models: description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -582,6 +694,8 @@ models: description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -594,6 +708,8 @@ models: description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false 
is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -606,6 +722,11 @@ models: description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + - audio + is_omni: true tags: - 大语言模型 - 多模态模型 @@ -620,6 +741,10 @@ models: description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -635,6 +760,10 @@ models: description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -650,6 +779,10 @@ models: description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -665,6 +798,10 @@ models: description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -680,6 +817,10 @@ models: description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -695,6 +836,10 @@ models: description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -708,6 +853,10 @@ models: description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -721,6 +870,8 @@ models: description: 
qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -732,6 +883,8 @@ models: description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -743,6 +896,8 @@ models: description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -754,6 +909,8 @@ models: description: gte-rerank-v2重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -763,6 +920,8 @@ models: description: gte-rerank重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -772,6 +931,9 @@ models: description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 嵌入模型 - 多模态模型 @@ -783,6 +945,8 @@ models: description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -793,6 +957,8 @@ models: description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -803,6 +969,8 @@ models: description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -813,7 +981,9 @@ models: description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 - logo: dashscope + logo: dashscope \ No newline at end of file diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index a14d3268..e4462efa 100644 --- 
a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -6,7 +6,7 @@ from typing import Callable import yaml from sqlalchemy.orm import Session -from app.models.models_model import ModelBase, ModelProvider +from app.models.models_model import ModelBase, ModelProvider, ModelConfig def _load_yaml_config(provider: ModelProvider) -> list[dict]: @@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") for model_data in models: + config_sync_fields = { + "logo": None, + "capability": None, + "is_omni": None, + "name": None, + "provider": None, + "type": None, + "description": None + } try: # 检查模型是否已存在 existing = db.query(ModelBase).filter( @@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) # 更新现有模型配置 for key, value in model_data.items(): setattr(existing, key, value) + + # 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey + sync_fields = [k for k in config_sync_fields.keys() if k in model_data] + if sync_fields: + # 批量更新 ModelConfig + update_kwargs = {k: model_data[k] for k in sync_fields} + db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update( + update_kwargs, + synchronize_session=False + ) + + # 更新 ModelApiKey 的 capability 和 is_omni + if 'capability' in model_data or 'is_omni' in model_data: + from app.models.models_model import ModelApiKey, model_config_api_key_association + api_key_update = {} + if 'capability' in model_data: + api_key_update['capability'] = model_data['capability'] + if 'is_omni' in model_data: + api_key_update['is_omni'] = model_data['is_omni'] + + if api_key_update: + # 查找所有关联的 API Key + api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join( + ModelConfig, + ModelConfig.id == model_config_api_key_association.c.model_config_id + ).filter(ModelConfig.model_id == existing.id).distinct().all() + + if api_key_ids: + api_key_ids = [aid[0] for 
aid in api_key_ids] + db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update( + api_key_update, + synchronize_session=False + ) + db.commit() if not silent: print(f"更新成功: {model_data['name']}") diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 68c63ee2..7f6d3a51 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -6,12 +6,19 @@ models: description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - audio + - video + is_omni: true tags: - 大语言模型 - multi-tool-call - agent-thought - stream-tool-call - vision + - audio + - video logo: openai - name: gpt-3.5-turbo-0125 type: llm @@ -19,6 +26,8 @@ models: description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -31,6 +40,8 @@ models: description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -43,6 +54,8 @@ models: description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -55,6 +68,8 @@ models: description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: openai @@ -64,6 +79,8 @@ models: description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -76,6 +93,8 @@ models: description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 
multi-tool-call @@ -88,6 +107,8 @@ models: description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +121,9 @@ models: description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -113,6 +137,8 @@ models: description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -125,6 +151,9 @@ models: description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -138,6 +167,8 @@ models: description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -148,6 +179,9 @@ models: description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -162,6 +196,9 @@ models: description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -176,6 +213,8 @@ models: description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -189,6 +228,8 @@ models: description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -202,6 +243,9 @@ models: description: 
o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -215,6 +259,9 @@ models: description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -228,6 +275,9 @@ models: description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -242,6 +292,9 @@ models: description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -256,6 +309,9 @@ models: description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -270,6 +326,8 @@ models: description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -279,6 +337,8 @@ models: description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -288,6 +348,8 @@ models: description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 - logo: openai + logo: openai \ No newline at end of file diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py index 2014b4c3..06c988d3 100644 --- a/api/app/core/workflow/adapters/dify/converter.py +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -98,7 +98,7 @@ class DifyConverter(BaseConverter): if not var_selector: return "" selector = 
var_selector.split('.') - if len(selector) not in [2, 3]: + if len(selector) not in [2, 3] and var_selector != "context": raise Exception(f"invalid variable selector: {var_selector}") if len(selector) == 3: selector = selector[1:] @@ -332,7 +332,9 @@ class DifyConverter(BaseConverter): messages.append( MessageConfig( role="user", - content=self.trans_variable_format(node_data["memory"]["query_prompt_template"]) + content=self.trans_variable_format( + node_data["memory"].get("query_prompt_template", "{{#sys.query#}}") + ) ) ) vision = node_data["vision"]["enabled"] diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py index dcd14c7f..6336b1f9 100644 --- a/api/app/core/workflow/adapters/dify/dify_adapter.py +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter): return True def validate_config(self) -> bool: - require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'}) + require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) if not all(field in self.config for field in require_fields): return False diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index d08f47e5..bc88df19 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -303,38 +303,52 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None - def get_all_system_vars(self) -> dict[str, Any]: + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 Returns: 系统变量字典 """ sys_namespace = self.variables.get("sys", {}) + if literal: + return {k: v.instance.to_literal() for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()} - def get_all_conversation_vars(self) -> dict[str, Any]: + def get_all_conversation_vars(self, literal=False) -> 
dict[str, Any]: """获取所有会话变量 Returns: 会话变量字典 """ conv_namespace = self.variables.get("conv", {}) + if literal: + return {k: v.instance.to_literal() for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()} - def get_all_node_outputs(self) -> dict[str, Any]: + def get_all_node_outputs(self, literal=False) -> dict[str, Any]: """获取所有节点输出(运行时变量) Returns: 节点输出字典,键为节点 ID """ - runtime_vars = { - namespace: { - k: v.instance.get_value() - for k, v in vars_dict.items() + if literal: + runtime_vars = { + namespace: { + k: v.instance.to_literal() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + else: + runtime_vars = { + namespace: { + k: v.instance.get_value() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") } - for namespace, vars_dict in self.variables.items() - if namespace not in ("sys", "conv") - } return runtime_vars def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 98d8bb75..8959e27c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.variable.base_variable import VariableType -from app.db import get_db +from app.db import get_db_context from app.models import AppRelease -from app.services.draft_run_service import DraftRunService +from app.services.draft_run_service import AgentRunService logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ class AgentNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return 
{"output": VariableType.STRING} - def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]: """准备 Agent(公共逻辑) Args: @@ -57,17 +57,17 @@ class AgentNode(BaseNode): if not agent_id: raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") - db = next(get_db()) - release = db.query(AppRelease).filter( - AppRelease.id == agent_id - ).first() + with get_db_context() as db: + release = db.query(AppRelease).filter( + AppRelease.id == agent_id + ).first() if not release: raise ValueError(f"Agent 不存在: {agent_id}") - draft_service = DraftRunService(db) + - return draft_service, release, message + return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """非流式执行 @@ -79,19 +79,21 @@ class AgentNode(BaseNode): Returns: 状态更新字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") - - # 执行 Agent(非流式) - result = await draft_service.run( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ) + with get_db_context() as db: + draft_service = AgentRunService(db) + + # 执行 Agent(非流式) + result = await draft_service.run( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ) response = result.get("response", "") @@ -118,34 +120,35 @@ class AgentNode(BaseNode): Yields: 流式事件字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") # 
累积完整响应 full_response = "" - + with get_db_context() as db: + draft_service = AgentRunService(db) # 执行 Agent(流式) - async for chunk in draft_service.run_stream( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ): - # 提取内容 - content = chunk.get("content", "") - full_response += content - - # 流式返回每个 chunk - yield { - "type": "chunk", - "node_id": self.node_id, - "content": content, - "full_content": full_response, - "meta_data": chunk.get("meta_data", {}) - } + async for chunk in draft_service.run_stream( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ): + # 提取内容 + content = chunk.get("content", "") + full_response += content + + # 流式返回每个 chunk + yield { + "type": "chunk", + "node_id": self.node_id, + "content": content, + "full_content": full_response, + "meta_data": chunk.get("meta_data", {}) + } logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 3e30c00e..3f30718c 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,5 +1,6 @@ import asyncio import logging +import uuid from abc import ABC, abstractmethod from functools import cached_property from typing import Any, AsyncGenerator @@ -10,8 +11,10 @@ from app.core.config import settings from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES -from app.core.workflow.variable.base_variable import VariableType -from app.services.multimodal_service import PROVIDER_STRATEGIES +from 
app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.schemas import FileInput +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -548,9 +551,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(literal=True), + node_outputs=variable_pool.get_all_node_outputs(literal=True), + system_vars=variable_pool.get_all_system_vars(literal=True), strict=strict ) @@ -614,16 +617,32 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider, content, enable_file=False) -> dict | str | None: + async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None: if isinstance(content, str): if enable_file: return {"text": content} return content - elif isinstance(content, dict): - trans_tool = PROVIDER_STRATEGIES[provider]() - result = await trans_tool.format_image(content["url"]) - return result - raise TypeError('Unexpect input value type') + + elif isinstance(content, FileObject): + if content.content_cache.get(provider): + return content.content_cache[provider] + with get_db_read() as db: + multimodel_service = MultimodalService(db, provider) + message = await multimodel_service.process_files( + [FileInput.model_construct( + type=content.type, + url=content.url, + transfer_method=content.transfer_method, + file_type=content.origin_file_type, + upload_file_id=content.file_id + )] + ) + + if message: + content.content_cache[provider] = message[0] + return message[0] + return None + raise TypeError(f'Unexpect input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: diff --git 
a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index e4026f2d..cf7ac976 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -91,8 +91,8 @@ class IterationRuntime: return loopstate def merge_conv_vars(self): - self.variable_pool.get_all_conversation_vars().update( - self.child_variable_pool.get_all_conversation_vars() + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables["conv"] ) async def run_task(self, item, idx): diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index cebadfdc..d3ada1ec 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -156,7 +156,7 @@ class LoopRuntime: def merge_conv_vars(self, loopstate): self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables.get("conv", {}) + self.child_variable_pool.variables["conv"] ) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index fdd5df58..c109d59b 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -172,9 +172,9 @@ class LLMNode(BaseNode): if self.typed_config.vision_input and self.typed_config.vision: file_content = [] - files = variable_pool.get_value(self.typed_config.vision_input) - for file in files: - content = await self.process_message(provider, file, self.typed_config.vision) + files = variable_pool.get_instance(self.typed_config.vision_input) + for file in files.value: + content = await self.process_message(provider, file.value, self.typed_config.vision) if content: file_content.append(content) if messages and messages[-1]["role"] == 'user': diff 
--git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index 19cbdc74..dd821ea7 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.schemas import FileType @@ -45,7 +45,7 @@ class VariableType(StrEnum): return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN - elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): + elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): return cls.OBJECT @@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: class FileObject(BaseModel): type: FileType url: str - __file: bool + transfer_method: str + origin_file_type: str + file_id: str | None + + content_cache: dict = Field(default_factory=dict) + + is_file: bool class BaseVariable(ABC): diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 49541afc..63437fd9 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -63,13 +63,16 @@ class FileVariable(BaseVariable): def valid_value(self, value) -> FileObject: if isinstance(value, dict): - if not value.get("__file"): + if not value.get("is_file"): raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") return FileObject( **{ "type": str(value.get('type')), + "transfer_method": value.get("transfer_method"), "url": value.get('url'), - "__file": True + "file_id": value.get("file_id"), + "origin_file_type": value.get("origin_file_type"), + "is_file": True } ) if isinstance(value, FileObject): diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 
b1b723e9..c6098a6d 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -35,6 +35,7 @@ from .ontology_scene import OntologyScene from .ontology_class import OntologyClass from .ontology_scene import OntologyScene from .ontology_class import OntologyClass +from .implicit_emotions_storage_model import ImplicitEmotionsStorage __all__ = [ "Tenants", @@ -90,5 +91,6 @@ __all__ = [ "MemoryPerceptualModel", "ModelBase", "LoadBalanceStrategy", - "Skill" + "Skill", + "ImplicitEmotionsStorage" ] diff --git a/api/app/models/implicit_emotions_storage_model.py b/api/app/models/implicit_emotions_storage_model.py new file mode 100644 index 00000000..cf654950 --- /dev/null +++ b/api/app/models/implicit_emotions_storage_model.py @@ -0,0 +1,45 @@ +""" +Implicit Emotions Storage Model + +数据库模型:存储用户的隐性记忆画像和情绪建议数据 +替代原有的Redis缓存方式 +""" +import uuid +from datetime import datetime +from sqlalchemy import Column, String, Text, DateTime, Index +from sqlalchemy.dialects.postgresql import UUID, JSONB +from app.db import Base + + +class ImplicitEmotionsStorage(Base): + """隐性记忆和情绪存储表""" + + __tablename__ = "implicit_emotions_storage" + + # 主键 + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID") + + # 用户标识(unique=True会自动创建唯一索引) + end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID") + + # 隐性记忆画像数据(JSON格式) + implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据") + + # 情绪建议数据(JSON格式) + emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据") + + # 时间戳 + created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间") + updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间") + + # 数据生成时间(用于业务逻辑) + implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间") + emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间") + + # 索引(只为updated_at创建索引,end_user_id的unique约束已自动创建索引) + 
__table_args__ = ( + Index('idx_updated_at', 'updated_at'), + ) + + def __repr__(self): + return f"" diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 3e378f17..23fafcef 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,7 +2,7 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship from sqlalchemy.sql import func @@ -78,6 +78,9 @@ class ModelConfig(BaseModel): description = Column(String, comment="模型描述") # 模型配置参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") config = Column(JSON, comment="模型配置参数") # - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。 # - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。 @@ -118,6 +121,11 @@ class ModelApiKey(BaseModel): api_key = Column(String, nullable=False, comment="API密钥") api_base = Column(String, comment="API基础URL") + # 模型能力参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") + # 配置参数 config = Column(JSON, comment="API Key特定配置") @@ -155,6 +163,9 @@ class ModelBase(Base): tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])") add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数") created_at = 
Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now()) + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") # 关联关系 configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan") diff --git a/api/app/repositories/implicit_emotions_storage_repository.py b/api/app/repositories/implicit_emotions_storage_repository.py new file mode 100644 index 00000000..97405ab6 --- /dev/null +++ b/api/app/repositories/implicit_emotions_storage_repository.py @@ -0,0 +1,169 @@ +""" +Implicit Emotions Storage Repository + +数据访问层:处理隐性记忆和情绪数据的数据库操作 +事务由调用方控制,仓储层只使用 flush/refresh +""" +import logging +from datetime import datetime, date, timezone, timedelta +from typing import Optional, Generator +from sqlalchemy.orm import Session +from sqlalchemy import select, not_, exists + +from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage +from app.models.end_user_model import EndUser + +logger = logging.getLogger(__name__) + + +class ImplicitEmotionsStorageRepository: + """隐性记忆和情绪存储仓储类""" + + def __init__(self, db: Session): + self.db = db + + def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]: + """根据终端用户ID获取存储记录""" + try: + stmt = select(ImplicitEmotionsStorage).where( + ImplicitEmotionsStorage.end_user_id == end_user_id + ) + return self.db.execute(stmt).scalar_one_or_none() + except Exception as e: + logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}") + return None + + def create(self, end_user_id: str) -> ImplicitEmotionsStorage: + """创建新的存储记录(事务由调用方提交)""" + storage = ImplicitEmotionsStorage( + end_user_id=end_user_id, + created_at=datetime.utcnow(), + updated_at=datetime.utcnow() + ) + self.db.add(storage) + self.db.flush() + 
self.db.refresh(storage) + logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}") + return storage + + def update_implicit_profile( + self, + end_user_id: str, + profile_data: dict + ) -> ImplicitEmotionsStorage: + """更新隐性记忆画像数据(事务由调用方提交)""" + storage = self.get_by_end_user_id(end_user_id) + if storage is None: + storage = self.create(end_user_id) + + storage.implicit_profile = profile_data + storage.implicit_generated_at = datetime.utcnow() + storage.updated_at = datetime.utcnow() + + self.db.flush() + self.db.refresh(storage) + logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}") + return storage + + def update_emotion_suggestions( + self, + end_user_id: str, + suggestions_data: dict + ) -> ImplicitEmotionsStorage: + """更新情绪建议数据(事务由调用方提交)""" + storage = self.get_by_end_user_id(end_user_id) + if storage is None: + storage = self.create(end_user_id) + + storage.emotion_suggestions = suggestions_data + storage.emotion_generated_at = datetime.utcnow() + storage.updated_at = datetime.utcnow() + + self.db.flush() + self.db.refresh(storage) + logger.info(f"更新情绪建议成功: end_user_id={end_user_id}") + return storage + + def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]: + """分批次获取所有已存储数据的用户ID(避免大数据量内存溢出) + + Args: + batch_size: 每批次加载的数量,默认100 + + Yields: + 用户ID字符串 + """ + offset = 0 + while True: + try: + stmt = ( + select(ImplicitEmotionsStorage.end_user_id) + .order_by(ImplicitEmotionsStorage.end_user_id) + .limit(batch_size) + .offset(offset) + ) + batch = self.db.execute(stmt).scalars().all() + if not batch: + break + yield from batch + offset += batch_size + except Exception as e: + logger.error(f"分批获取用户ID失败: offset={offset}, error={e}") + break + + def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]: + """分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID + + 查询逻辑:end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。 + 没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。 + end_users.id(UUID)转为字符串后与 
implicit_emotions_storage.end_user_id(String)对比。 + + Args: + batch_size: 每批次加载的数量,默认100 + + Yields: + 用户ID字符串 + """ + from sqlalchemy import cast, String as SAString + CST = timezone(timedelta(hours=8)) + now_cst = datetime.now(CST) + today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None) + tomorrow_start = today_start + timedelta(days=1) + offset = 0 + while True: + try: + stmt = ( + select(EndUser.id) + .where( + EndUser.created_at >= today_start, + EndUser.created_at < tomorrow_start, + not_( + exists( + select(ImplicitEmotionsStorage.end_user_id).where( + ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString) + ) + ) + ) + ) + .order_by(EndUser.id) + .limit(batch_size) + .offset(offset) + ) + batch = self.db.execute(stmt).scalars().all() + if not batch: + break + yield from (str(uid) for uid in batch) + offset += batch_size + except Exception as e: + logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}") + break + + def delete_by_end_user_id(self, end_user_id: str) -> bool: + """删除用户的存储记录(事务由调用方提交)""" + storage = self.get_by_end_user_id(end_user_id) + if storage: + self.db.delete(storage) + self.db.flush() + logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}") + return True + return False diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 681d1c10..e3832214 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int except Exception as e: db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}") raise + + +def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和 + """ + db_logger.debug(f"Query user KB chunk_num by 
workspace_id: workspace_id={workspace_id}") + + try: + from sqlalchemy import func + result = db.query(func.sum(Knowledge.chunk_num)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id == "Memory" + ).scalar() + + total = result if result is not None else 0 + db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}") + return total + except Exception as e: + db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}") + raise + + +def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量 + """ + db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}") + + try: + count = db.query(Knowledge).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id != "Memory" + ).count() + + db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}") + return count + except Exception as e: + db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}") + raise + diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py index 141b5d1c..0b357e41 100644 --- a/api/app/repositories/ontology_scene_repository.py +++ b/api/app/repositories/ontology_scene_repository.py @@ -374,7 +374,7 @@ class OntologySceneRepository: count = self.db.query(OntologyScene).filter( OntologyScene.scene_id == scene_id, - OntologyScene.workspace_id == workspace_id + (OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True) ).count() is_owner = count > 0 diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index d19cf061..c7ca1e55 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -15,7 +15,7 @@ class 
ApiKeyCreate(BaseModel): type: ApiKeyType = Field(..., description="API Key 类型") scopes: List[str] = Field(default_factory=list, description="权限范围列表") resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)") + rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制(请求/秒)") daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") @@ -155,8 +155,7 @@ class ApiKey(BaseModel): return datetime.datetime.now() > self.expires_at @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel): avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") @field_serializer('last_used_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel): created_at: datetime.datetime @field_serializer('created_at') - @classmethod - def serialize_datetime(cls, v: datetime.datetime) -> int: + def serialize_datetime(self, v: datetime.datetime) -> int: """将datetime转换为时间戳""" return datetime_to_timestamp(v) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 07875e13..f073a200 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -21,8 +21,14 @@ class FileType(StrEnum): def trans(cls, value: str) -> 'FileType': if value.startswith("image"): return 
cls.IMAGE - # TODO: other file type support - raise RuntimeError("Unsupport file type") + elif value.startswith("document"): + return cls.DOCUMENT + elif value.startswith("audio"): + return cls.AUDIO + elif value.startswith("video"): + return cls.VIDEO + else: + raise RuntimeError("Unsupport file type") class TransferMethod(str, Enum): @@ -37,6 +43,12 @@ class FileInput(BaseModel): transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url") upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)") url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") + file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + + def __init__(self, **data): + if "type" in data: + data['file_type'] = data['type'] + super().__init__(**data) @field_validator("type", mode="before") @classmethod diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index cef9b9cb..ce8f70f2 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel): class ChunkRetrieve(BaseModel): query: str kb_ids: list[uuid.UUID] + file_names_filter: list[str] | None = Field(None) similarity_threshold: float | None = Field(None) vector_similarity_weight: float | None = Field(None) top_k: int | None = Field(None) diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 0c0bbeed..ea4183a5 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel): is_active: bool = Field(True, description="是否激活") is_public: bool = Field(False, description="是否公开") load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略") + capability: List[str] = Field(default_factory=list, description="模型能力列表") + is_omni: bool = Field(False, description="是否为Omni模型") class 
ApiKeyCreateNested(BaseModel): @@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel): provider: Optional[str] = Field(None, description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") priority: str = Field("1", description="优先级", max_length=10) @@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel): config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数") is_active: Optional[bool] = Field(None, description="是否激活") is_public: Optional[bool] = Field(None, description="是否公开") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelConfig(ModelConfigBase): @@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel): api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) description: Optional[str] = Field(None, description="备注") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel): provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, 
description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel): provider: Optional[ModelProvider] = Field(None, description="API Key提供商") api_key: Optional[str] = Field(None, description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") is_active: Optional[bool] = Field(None, description="是否激活") priority: Optional[str] = Field(None, description="优先级", max_length=10) @@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel): description: Optional[str] = Field(None, description="模型描述") is_official: bool = Field(True, description="是否供应商官方模型") tags: List[str] = Field(default_factory=list, description="模型标签") + capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video'])") + is_omni: bool = Field(False, description="是否为Omni模型") class ModelBaseUpdate(BaseModel): @@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_official: Optional[bool] = Field(None, description="是否供应商官方模型") tags: Optional[List[str]] = Field(None, description="模型标签") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelBase(BaseModel): @@ -298,6 +314,8 @@ class ModelBase(BaseModel): is_official: bool tags: List[str] add_count: int + capability: List[str] = [] + is_omni: bool = False class ModelBaseQuery(BaseModel): diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index 8fba2929..3573e87c 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py 
@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel): class MultiAgentConfigCreate(BaseModel): """创建多 Agent 配置""" master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") orchestration_mode: str = Field( default="collaboration", pattern="^(collaboration|supervisor)$", description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") - routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") + routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") aggregation_strategy: str = Field( default="merge", @@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel): class MultiAgentConfigUpdate(BaseModel): """更新多 Agent 配置""" master_agent_id: Optional[uuid.UUID] = None - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field( None, diff --git a/api/app/schemas/ontology_schemas.py b/api/app/schemas/ontology_schemas.py index 88ecd712..905e65fe 100644 --- a/api/app/schemas/ontology_schemas.py +++ b/api/app/schemas/ontology_schemas.py @@ -241,6 +241,7 @@ class SceneResponse(BaseModel): created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") classes_count: int = Field(0, description="类型数量") + is_system_default: bool = Field(False, description="是否为系统默认场景") @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: 
datetime.datetime): @@ -462,6 +463,7 @@ class ClassListResponse(BaseModel): scene_id: UUID = Field(..., description="所属场景ID") scene_name: str = Field(..., description="场景名称") scene_description: Optional[str] = Field(None, description="场景描述") + is_system_default: bool = Field(False, description="是否为系统默认场景") items: List[ClassResponse] = Field(..., description="类型列表") diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 3ca7bddd..a4768b51 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -263,8 +263,8 @@ def create_agent_invocation_tool( try: # 9. 调用 Agent - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_config, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 9723121d..5430d2f9 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,25 +10,24 @@ from sqlalchemy.orm import Session from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.db import get_db, get_db_context -from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig -from app.schemas import DraftRunRequest -from app.schemas.app_schema import FileInput -from app.services.tool_service import ToolService -from app.repositories.tool_repository import ToolRepository from app.db import get_db from app.models import MultiAgentConfig, AgentConfig +from app.models import WorkflowConfig +from app.repositories.tool_repository import ToolRepository +from app.schemas import DraftRunRequest +from app.schemas.app_schema import FileInput from 
app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool +from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ + AgentRunService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator -from app.services.workflow_service import WorkflowService from app.services.multimodal_service import MultimodalService +from app.services.tool_service import ToolService +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -39,6 +38,8 @@ class AppChatService: def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) + self.agent_service = AgentRunService(db) + self.workflow_service = WorkflowService(db) async def agnet_chat( self, @@ -55,12 +56,10 @@ class AppChatService: files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> Dict[str, Any]: """聊天(非流式)""" - start_time = time.time() config_id = None - if variables is None: - variables = {} + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id @@ -79,74 +78,20 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - # 从配置中获取启用的工具 - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - 
continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + 
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) memory_flag = False - if memory == True: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + if memory: + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -157,6 +102,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -180,7 +126,7 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") @@ -245,10 +191,9 @@ class AppChatService: try: start_time = time.time() config_id = None + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - if variables is None: - variables = {} - + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) @@ -266,73 +211,22 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if 
tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = 
self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -343,6 +237,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -366,13 +261,10 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - # 流式调用 Agent(支持多模态) full_content = "" total_tokens = 0 @@ -416,7 +308,7 @@ class AppChatService: ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} + end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" logger.info( @@ -435,7 +327,7 @@ class AppChatService: except Exception 
as e: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" async def multi_agent_chat( self, @@ -489,10 +381,10 @@ class AppChatService: "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }) + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) @@ -522,8 +414,6 @@ class AppChatService: """多 Agent 聊天(流式)""" start_time = time.time() - actual_config_id = None - config_id = actual_config_id if variables is None: variables = {} @@ -629,7 +519,6 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -637,7 +526,7 @@ class AppChatService: stream=True, user_id=user_id ) - return await workflow_service.run( + return await self.workflow_service.run( app_id=app_id, payload=payload, config=config, @@ -664,7 +553,6 @@ class AppChatService: ) -> AsyncGenerator[dict, None]: """聊天(流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -673,7 +561,7 @@ class AppChatService: user_id=user_id, files=files ) - async for event in workflow_service.run_stream( + async for event in self.workflow_service.run_stream( app_id=app_id, payload=payload, config=config, diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 6e6e0ecb..a248f869 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -232,7 +232,7 @@ class AppService: # 检查主 Agent 的模型配置 multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id - model_api_key = ModelApiKeyService.get_a_api_key(self.db, 
multi_agent_config.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) @@ -1791,372 +1791,6 @@ class AppService: return shares - # ==================== 试运行功能 ==================== - - async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """试运行 Agent(使用当前草稿配置) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Dict: 包含 AI 回复和元数据的字典 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 
调用试运行服务 - logger.debug( - "准备调用试运行服务", - extra={ - "app_id": str(app_id), - "model": model_config.name, - "has_conversation_id": bool(conversation_id), - "has_variables": bool(variables) - } - ) - - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ) - - logger.debug( - "试运行服务返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict", - "has_message": "message" in result if isinstance(result, dict) else False, - "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False - } - ) - - logger.info( - "试运行完成", - extra={ - "app_id": str(app_id), - "elapsed_time": result.get("elapsed_time"), - "model": model_config.name - } - ) - - return result - - async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ): - """试运行 Agent(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Yields: - str: SSE 格式的事件数据 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 
获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用流式试运行服务 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ): - yield event - - # ==================== 多模型对比试运行 ==================== - - async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models), - "parallel": parallel - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 
获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的对比方法 - draft_service = DraftRunService(self.db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ) - - logger.info( - "多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return result - - async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ): - """多模型对比试运行(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比流式试运行", - extra={ - 
"app_id": str(app_id), - "model_count": len(models) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 
调用 DraftRunService 的流式对比方法 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ): - yield event - - logger.info( - "多模型对比流式完成", - extra={"app_id": str(app_id)} - ) - - # ==================== 向后兼容的函数接口 ==================== # 保留函数接口以兼容现有代码,但内部使用服务类 @@ -2278,53 +1912,6 @@ def get_apps_by_ids( return service.get_apps_by_ids(app_ids, workspace_id) -# ==================== 向后兼容的函数接口 ==================== - -async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -) -> Dict[str, Any]: - """试运行 Agent(向后兼容接口)""" - service = AppService(db) - return await service.draft_run( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ) - - -async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -): - """试运行 Agent 流式返回(向后兼容接口)""" - service = AppService(db) - async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ): - yield event - - # ==================== 依赖注入函数 ==================== def get_app_service( diff --git a/api/app/services/audio_transcription_service.py b/api/app/services/audio_transcription_service.py new file mode 100644 index 00000000..11d13f38 --- /dev/null +++ b/api/app/services/audio_transcription_service.py @@ -0,0 +1,101 
@@ +""" +音频转文本服务 + +支持的服务商: +- DashScope (阿里云通义千问) +- OpenAI Whisper +""" +import httpx + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class AudioTranscriptionService: + """音频转文本服务""" + + @staticmethod + async def transcribe_dashscope(audio_url: str, api_key: str) -> str: + """ + 使用阿里云通义千问语音识别服务转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: DashScope API Key + + Returns: + str: 转录的文本 + """ + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "X-DashScope-Async": "enable", + }, + json={ + "model": "paraformer-v2", + "input": { + "file_urls": [audio_url] + }, + "parameters": { + "language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"] + } + } + ) + response.raise_for_status() + result = response.json() + + if result.get("output", {}).get("results"): + text = result["output"]["results"][0].get("transcription_text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + return "[音频转文本失败]" + + except Exception as e: + logger.error(f"DashScope 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" + + @staticmethod + async def transcribe_openai(audio_url: str, api_key: str) -> str: + """ + 使用 OpenAI Whisper 转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: OpenAI API Key + + Returns: + str: 转录的文本 + """ + try: + # 下载音频文件 + async with httpx.AsyncClient(timeout=60.0) as client: + audio_response = await client.get(audio_url) + audio_response.raise_for_status() + audio_data = audio_response.content + + # 调用 Whisper API + files = {"file": ("audio.mp3", audio_data, "audio/mpeg")} + data = {"model": "whisper-1"} + + response = await client.post( + "https://api.openai.com/v1/audio/transcriptions", + headers={"Authorization": f"Bearer {api_key}"}, + files=files, + data=data + ) + response.raise_for_status() 
+ result = response.json() + + text = result.get("text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + except Exception as e: + logger.error(f"OpenAI Whisper 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" diff --git a/api/app/services/collaborative_orchestrator.py b/api/app/services/collaborative_orchestrator.py index 00a731de..68181cd1 100644 --- a/api/app/services/collaborative_orchestrator.py +++ b/api/app/services/collaborative_orchestrator.py @@ -445,6 +445,7 @@ class CollaborativeOrchestrator: "provider": api_key_config.provider, "api_key": api_key_config.api_key, "api_base": api_key_config.api_base, + "is_omni": api_key_config.is_omni, "model_parameters": config_data.get("model_parameters", {}), "api_key_id": api_key_config.id } @@ -511,6 +512,7 @@ class CollaborativeOrchestrator: provider=agent_config["provider"], api_key=agent_config["api_key"], base_url=agent_config.get("api_base"), + is_omni=agent_config.get("is_omni", False), extra_params=extra_params ) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 8977710b..5026bf27 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -17,15 +17,18 @@ from sqlalchemy.orm import Session from app.celery_app import celery_app from app.core.agent.agent_middleware import AgentMiddleware +from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service +from app.services.conversation_service import ConversationService 
from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger @@ -52,8 +55,12 @@ class LongTermMemoryInput(BaseModel): description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") -def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None): +def create_long_term_memory_tool( + memory_config: Dict[str, Any], + end_user_id: str, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None +): """创建记忆工具, @@ -61,6 +68,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str memory_config: 记忆配置 end_user_id: 用户ID storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) Returns: 长期记忆工具 @@ -96,9 +104,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str """ logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") try: - from app.db import get_db - db = next(get_db()) - try: + with get_db_context() as db: memory_content = asyncio.run( MemoryAgentService().read_memory( end_user_id=end_user_id, @@ -120,9 +126,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str logger.info(f"读取任务状态:{status}") if memory_content: memory_content = memory_content['answer'] - - finally: - db.close() logger.info(f'用户ID:Agent:{end_user_id}') logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) @@ -188,7 +191,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: - query: 需要检索的问题或关键词 + kb_config: 知识库配置 + kb_ids: 知识库ID列表 + user_id: 用户ID Returns: 检索到的相关知识内容 @@ -232,17 +237,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): return knowledge_retrieval_tool -class DraftRunService: - """试运行服务类""" +class AgentRunService: + 
"""Agent运行服务类""" def __init__(self, db: Session): - """初始化试运行服务 + """Agent运行服务 Args: db: 数据库会话 """ self.db = db + @staticmethod + def prepare_variables( + input_vars: dict | None, + variables_config: dict + ) -> dict: + input_vars = input_vars or {} + for variable in variables_config: + if variable.get("required") and variable.get("name") not in input_vars: + raise ValueError(f"The required parameter '{variable.get('name')}' was not provided") + return input_vars + + def load_tools_config(self, tools_config, web_search, tenant_id) -> list: + """加载工具配置""" + if not tools_config: + return [] + tools = [] + tool_service = ToolService(self.db) + + if tools_config and isinstance(tools_config, list): + for tool_config in tools_config: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif tools_config and isinstance(tools_config, dict): + web_search_choice = tools_config.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search and web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + return tools + + def load_skill_config( + self, + skills_config: dict | None, + message: str, tenant_id + ) -> tuple[list, str]: + if not skills_config: + return [], "" + + tools = [] + skill_prompts = "" + skill_enable = skills_config.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills_config) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + if 
skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + skill_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + + return tools, skill_prompts + + def load_knowledge_retrieval_config( + self, + knowledge_retrieval_config: dict | None, + user_id + ) -> list: + if not knowledge_retrieval_config: + return [] + + tools = [] + knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) + kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + if kb_ids: + # 创建知识库检索工具 + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + tools.append(kb_tool) + + logger.debug( + "已添加知识库检索工具", + extra={ + "kb_ids": kb_ids, + "tool_count": len(tools) + } + ) + return tools + + def load_memory_config( + self, + memory_config: dict | None, + user_id, + storage_type, + user_rag_memory_id + ) -> tuple[list, bool]: + """加载长期记忆配置""" + if not memory_config: + return [], False + + tools = [] + if memory_config.get("enabled"): + if user_id: + # 创建长期记忆工具 + memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.append(memory_tool) + + logger.debug( + "已添加长期记忆工具", + extra={ + "user_id": user_id, + "tool_count": len(tools) + } + ) + return tools, bool(memory_config.get("enabled")) + async def run( self, *, @@ -270,19 +399,21 @@ class DraftRunService: conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 + storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) + web_search: 是否启用网络搜索(默认True) + memory: 是否启用长期记忆(默认True) + sub_agent: 是否为子代理调用(默认False) + files: 多模态文件列表(可选) Returns: Dict: 包含 AI 回复和元数据的字典 """ - memory_flag = False - - print('===========', storage_type) - - print(user_id) - if variables == None: variables = {} - from app.core.agent.langchain_agent import LangChainAgent - start_time = time.time() + tools_config: dict | 
list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory try: # 1. 获取 API Key 配置 @@ -302,112 +433,40 @@ class DraftRunService: agent_config=agent_config ) - items_params = variables + if sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} + system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 + agent_config.system_prompt, PromptMessageRole.USER, - items_params + variables ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - print('系统提示词:', system_prompt) # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_config in agent_config.tools: - print("+" * 50) - print(f"agent_config:{agent_config}") - print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - 
logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": 
user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config( + memory_config, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -415,6 +474,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -431,7 +491,7 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, max_history=agent_config.memory.get("max_history", 10) @@ -442,7 +502,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -481,7 +541,7 @@ class DraftRunService: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 
保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -556,16 +616,21 @@ class DraftRunService: Yields: str: SSE 格式的事件数据 """ - memory_flag = False - if variables == None: variables = {} - - from app.core.agent.langchain_agent import LangChainAgent + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory start_time = time.time() try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) + if not sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( @@ -587,95 +652,22 @@ class DraftRunService: # 4. 
准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - for tool_config in agent_config.tools: - # print("+"*50) - # print(f"agent_config:{agent_config}") - # print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if 
skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -683,6 +675,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -700,10 +693,10 @@ class DraftRunService: # 6. 
加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) + max_history=memory_config.get("max_history", 10) ) # 6. 处理多模态文件 @@ -711,7 +704,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -761,7 +754,7 @@ class DraftRunService: }) # 10. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -809,7 +802,7 @@ class DraftRunService: """ return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]: + async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict: """获取模型的 API Key Args: @@ -846,7 +839,8 @@ class DraftRunService: "provider": api_key.provider, "api_key": api_key.api_key, "api_base": api_key.api_base, - "api_key_id": api_key.id + "api_key_id": api_key.id, + "is_omni": api_key.is_omni } async def _ensure_conversation( @@ -966,7 +960,6 @@ class DraftRunService: List[Dict]: 历史消息列表 """ try: - from app.services.conversation_service import ConversationService conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( @@ -1486,6 +1479,15 @@ class DraftRunService: "conversation_id": returned_conversation_id, "content": chunk })) + + if 
event_type == "error" and event_data: + await event_queue.put(self._format_sse_event("model_error", { + "model_index": idx, + "model_config_id": model_config_id, + "label": model_label, + "conversation_id": returned_conversation_id, + "error": event_data.get("error", "未知错误") + })) except Exception as e: logger.warning(f"解析流式事件失败: {e}") finally: @@ -1670,41 +1672,3 @@ class DraftRunService: "total_time": sum(r.get("elapsed_time", 0) for r in results) } ) - - -async def draft_run( - db: Session, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - user_id: Optional[str] = None, - kb_ids: Optional[List[str]] = None, - similarity_threshold: float = 0.7, - top_k: int = 3 -) -> Dict[str, Any]: - """试运行 Agent(便捷函数) - - Args: - db: 数据库会话 - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - user_id: 用户ID - kb_ids: 知识库ID列表 - similarity_threshold: 相似度阈值 - top_k: 检索返回的文档数量 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - service = DraftRunService(db) - return await service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - user_id=user_id, - kb_ids=kb_ids, - similarity_threshold=similarity_threshold, - top_k=top_k - ) diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index 89e3cab9..c226348e 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -843,32 +843,33 @@ class EmotionAnalyticsService: end_user_id: str, db: Session, ) -> Optional[Dict[str, Any]]: - """从 Redis 缓存获取个性化情绪建议 + """从数据库获取个性化情绪建议 Args: end_user_id: 宿主ID(用户组ID) - db: 数据库会话(保留参数以保持接口兼容性) + db: 数据库会话 Returns: - Dict: 缓存的建议数据,如果不存在或已过期返回 None + Dict: 存储的建议数据,如果不存在返回 None """ try: - from app.cache.memory.emotion_memory import EmotionMemoryCache + from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository - logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}") + logger.info(f"尝试从数据库获取情绪建议: 
user={end_user_id}") - # 从 Redis 获取缓存 - cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id) + # 从数据库获取存储记录 + repo = ImplicitEmotionsStorageRepository(db) + storage = repo.get_by_end_user_id(end_user_id) - if cached_data is None: - logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期") + if storage is None or storage.emotion_suggestions is None: + logger.info(f"用户 {end_user_id} 的建议数据不存在") return None - logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}") - return cached_data + logger.info(f"成功从数据库获取建议: user={end_user_id}") + return storage.emotion_suggestions except Exception as e: - logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True) + logger.error(f"从数据库获取建议失败: {str(e)}", exc_info=True) return None async def save_suggestions_cache( @@ -876,36 +877,27 @@ class EmotionAnalyticsService: end_user_id: str, suggestions_data: Dict[str, Any], db: Session, - expires_hours: int = 24 + expires_hours: int = 24 # 参数保留以保持接口兼容性 ) -> None: - """保存建议到 Redis 缓存 + """保存建议到数据库 Args: end_user_id: 宿主ID(用户组ID) suggestions_data: 建议数据 - db: 数据库会话(保留参数以保持接口兼容性) - expires_hours: 过期时间(小时),默认24小时 + db: 数据库会话 + expires_hours: 保留参数(兼容性) """ try: - from app.cache.memory.emotion_memory import EmotionMemoryCache + from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository - logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时") + logger.info(f"保存建议到数据库: user={end_user_id}") - # 计算过期时间(秒) - expire_seconds = expires_hours * 3600 + repo = ImplicitEmotionsStorageRepository(db) + repo.update_emotion_suggestions(end_user_id, suggestions_data) + db.commit() - # 保存到 Redis - success = await EmotionMemoryCache.set_emotion_suggestions( - user_id=end_user_id, - suggestions_data=suggestions_data, - expire=expire_seconds - ) - - if success: - logger.info(f"建议缓存保存成功: user={end_user_id}") - else: - logger.warning(f"建议缓存保存失败: user={end_user_id}") + logger.info(f"建议保存成功: user={end_user_id}") except Exception as e: - 
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True) - # 不抛出异常,缓存失败不应影响主流程 \ No newline at end of file + db.rollback() + logger.error(f"保存建议失败: {str(e)}", exc_info=True) \ No newline at end of file diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index e490eea4..8418fe31 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs( provider=model_api_key.provider, api_key=model_api_key.api_key, base_url=model_api_key.api_base, + is_omni=model_api_key.is_omni, extra_params={ "temperature": 0.7, "max_tokens": 2000, diff --git a/api/app/services/implicit_memory_service.py b/api/app/services/implicit_memory_service.py index 34ebe880..4bd11deb 100644 --- a/api/app/services/implicit_memory_service.py +++ b/api/app/services/implicit_memory_service.py @@ -422,32 +422,33 @@ class ImplicitMemoryService: end_user_id: str, db: Session ) -> Optional[dict]: - """从 Redis 缓存获取完整用户画像 + """从数据库获取完整用户画像 Args: end_user_id: 终端用户ID - db: 数据库会话(保留参数以保持接口兼容性) + db: 数据库会话 Returns: - Dict: 缓存的画像数据,如果不存在或已过期返回 None + Dict: 存储的画像数据,如果不存在返回 None """ try: - from app.cache.memory.implicit_memory import ImplicitMemoryCache + from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository - logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}") + logger.info(f"尝试从数据库获取用户画像: user={end_user_id}") - # 从 Redis 获取缓存 - cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id) + # 从数据库获取存储记录 + repo = ImplicitEmotionsStorageRepository(db) + storage = repo.get_by_end_user_id(end_user_id) - if cached_data is None: - logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期") + if storage is None or storage.implicit_profile is None: + logger.info(f"用户 {end_user_id} 的画像数据不存在") return None - logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}") - return cached_data + logger.info(f"成功从数据库获取用户画像: user={end_user_id}") + return 
storage.implicit_profile except Exception as e: - logger.error(f"从 Redis 缓存获取用户画像失败: {str(e)}", exc_info=True) + logger.error(f"从数据库获取用户画像失败: {str(e)}", exc_info=True) return None async def save_profile_cache( @@ -455,36 +456,27 @@ class ImplicitMemoryService: end_user_id: str, profile_data: dict, db: Session, - expires_hours: int = 168 # 默认7天 + expires_hours: int = 168 # 参数保留以保持接口兼容性 ) -> None: - """保存用户画像到 Redis 缓存 + """保存用户画像到数据库 Args: end_user_id: 终端用户ID profile_data: 画像数据 - db: 数据库会话(保留参数以保持接口兼容性) - expires_hours: 过期时间(小时),默认168小时(7天) + db: 数据库会话 + expires_hours: 保留参数(兼容性) """ try: - from app.cache.memory.implicit_memory import ImplicitMemoryCache + from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository - logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时") + logger.info(f"保存用户画像到数据库: user={end_user_id}") - # 计算过期时间(秒) - expire_seconds = expires_hours * 3600 + repo = ImplicitEmotionsStorageRepository(db) + repo.update_implicit_profile(end_user_id, profile_data) + db.commit() - # 保存到 Redis - success = await ImplicitMemoryCache.set_user_profile( - user_id=end_user_id, - profile_data=profile_data, - expire=expire_seconds - ) - - if success: - logger.info(f"用户画像缓存保存成功: user={end_user_id}") - else: - logger.warning(f"用户画像缓存保存失败: user={end_user_id}") + logger.info(f"用户画像保存成功: user={end_user_id}") except Exception as e: - logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True) - # 不抛出异常,缓存失败不应影响主流程 + db.rollback() + logger.error(f"保存用户画像失败: {str(e)}", exc_info=True) diff --git a/api/app/services/langchain_tool_server.py b/api/app/services/langchain_tool_server.py index f44e4cdc..2c151956 100644 --- a/api/app/services/langchain_tool_server.py +++ b/api/app/services/langchain_tool_server.py @@ -9,6 +9,8 @@ load_dotenv() # 读取web_search环境变量 web_search_value = os.getenv('web_search') + + def Search(query): url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" api_key = web_search_value @@ 
-18,23 +20,24 @@ def Search(query): "role": "user", "content": query } - ], #搜索输入 - "edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 - "search_source": "baidu_search_v2", #使用的搜索引擎版本 - "resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 + ], # 搜索输入 + "edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 + "search_source": "baidu_search_v2", # 使用的搜索引擎版本 + "resource_type_filter": [{"type": "web", "top_k": 20}], + # 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 "search_filter": { "range": { "page_time": { - "gte": "now-1w/d", #时间查询参数,大于或等于 - "lt": "now/d", #时间查询参数,小于 - "gt": "", #时间查询参数,大于 - "lte": "" #时间查询参数,小于或等于 + "gte": "now-1w/d", # 时间查询参数,大于或等于 + "lt": "now/d", # 时间查询参数,小于 + "gt": "", # 时间查询参数,大于 + "lte": "" # 时间查询参数,小于或等于 } } }, - "block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 - "search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year - "enable_full_content":True #是否输出网页完整原文 + "block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表 + "search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year + "enable_full_content": True # 是否输出网页完整原文 }, ensure_ascii=False) headers = { 'Content-Type': 'application/json', @@ -42,10 +45,10 @@ def Search(query): } response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() - content=[] + content = [] for i in response['references']: - title=i['title'] - snippet=i['snippet'] - content.append(title+';'+snippet) - content='。'.join(content) - return content \ No newline at end of file + title = i['title'] + snippet = i['snippet'] + content.append(title + ';' + snippet) + content = '。'.join(content) + return content diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index e56ad5aa..02895d6b 100644 --- a/api/app/services/llm_router.py 
+++ b/api/app/services/llm_router.py @@ -414,6 +414,7 @@ class LLMRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.3, max_tokens=500 ) diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 3cf3ecc3..b0f43b51 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -392,6 +392,7 @@ class MasterAgentRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, extra_params = extra_params ) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 16aee283..f272c541 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config """ import json import os -import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional @@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import ( reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle -from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution +from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError @@ -69,7 +66,8 @@ class MemoryAgentService: 
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, duration=duration, details={"message_length": len(message)}) return context else: @@ -88,8 +86,6 @@ class MemoryAgentService: raise ValueError(f"写入失败: {messages}") - - def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -271,7 +267,8 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, + db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: """ Process write operation with config_id @@ -300,7 +297,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. 
Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -331,7 +329,8 @@ class MemoryAgentService: # Log failed operation if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -351,9 +350,9 @@ class MemoryAgentService: langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': langchain_messages.append(AIMessage(content=msg['content'])) - print(100*'-') + print(100 * '-') print(langchain_messages) - print(100*'-') + print(100 * '-') # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, @@ -375,29 +374,28 @@ class MemoryAgentService: contents = massages.get('write_result') # Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, + contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) - - - async def read_memory( - self, - end_user_id: str, - message: 
str, - history: List[Dict], - search_switch: str, - config_id: Optional[uuid.UUID]|int, - db: Session, - storage_type: str, - user_rag_memory_id: str) -> Dict: + self, + end_user_id: str, + message: str, + history: List[Dict], + search_switch: str, + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str) -> Dict: """ Process read operation with config_id @@ -425,7 +423,7 @@ class MemoryAgentService: import time start_time = time.time() - ori_message= message + ori_message = message # Resolve config_id and workspace_id # Always get workspace_id from end_user for fallback, even if config_id is provided @@ -437,7 +435,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. 
Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -454,7 +453,6 @@ class MemoryAgentService: except ImportError: audit_logger = None - config_load_start = time.time() try: # Use a separate database session to avoid transaction failures @@ -562,34 +560,35 @@ class MemoryAgentService: from app.repositories.memory_short_repository import ( ShortTermMemoryRepository, ) - + retrieved_content = [] repo = ShortTermMemoryRepository(db) - + if str(search_switch) != "2": for intermediate in _intermediate_outputs: logger.debug(f"处理中间结果: {intermediate}") intermediate_type = intermediate.get('type', '') - + if intermediate_type == "search_result": query = intermediate.get('query', '') raw_results = intermediate.get('raw_results', {}) try: reranked_results = raw_results.get('reranked_results', []) - statements = [statement['statement'] for statement in reranked_results.get('statements', [])] + statements = [statement['statement'] for statement in + reranked_results.get('statements', [])] except Exception: statements = [] - + # 去重 statements = list(set(statements)) - + if query and statements: retrieved_content.append({query: statements}) - + # 如果 retrieved_content 为空,设置为空字符串 if retrieved_content == []: retrieved_content = '' - + # 只有当回答不是"信息不足"且不是快速检索时才保存 if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # 使用 upsert 方法 @@ -602,15 +601,17 @@ class MemoryAgentService: ) logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") else: - logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") - + logger.debug( + f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") + except Exception as save_error: # 保存失败不应该影响主流程,只记录错误 logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation total_time = time.time() - 
start_time - logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") + logger.info( + f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -641,7 +642,6 @@ class MemoryAgentService: ) raise ValueError(error_msg) - def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: """ Get standardized message list from user input. @@ -657,41 +657,43 @@ class MemoryAgentService: """ from app.core.logging_config import get_api_logger logger = get_api_logger() - + if len(user_input.messages) == 0: logger.error("Validation failed: Message list cannot be empty") raise ValueError("Message list cannot be empty") - + for idx, msg in enumerate(user_input.messages): if not isinstance(msg, dict): logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") - raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") - + raise ValueError( + f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + if 'role' not in msg: logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") - + if 'content' not in msg: logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") - raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") - + raise ValueError( + f"Message format error: Message must contain 'content' field. 
Error message index: {idx}") + if msg['role'] not in ['user', 'assistant']: logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") - + if not msg['content'] or not msg['content'].strip(): logger.error(f"Validation failed: Message {idx} content is empty") raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") - + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") return user_input.messages async def classify_message_type( - self, - message: str, - config_id: UUID, - db: Session, - workspace_id: Optional[UUID] = None + self, + message: str, + config_id: UUID, + db: Session, + workspace_id: Optional[UUID] = None ) -> Dict: """ Determine the type of user message (read or write) @@ -719,14 +721,15 @@ class MemoryAgentService: status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status + async def generate_summary_from_retrieve( - self, - end_user_id: str, - retrieve_info: str, - history: List[Dict], - query: str, - config_id: str, - db: Session + self, + end_user_id: str, + retrieve_info: str, + history: List[Dict], + query: str, + config_id: str, + db: Session ) -> str: """ 基于检索信息、历史对话和查询生成最终答案 @@ -761,9 +764,9 @@ class MemoryAgentService: if config_id is None: raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") # If config_id was provided, continue without workspace_id fallback - + logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") - + try: # 加载配置 config_service = MemoryConfigService(db) @@ -772,7 +775,7 @@ class MemoryAgentService: workspace_id=workspace_id, service_name="MemoryAgentService" ) - + # 导入必要的模块 from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( summary_llm, @@ -780,13 +783,13 @@ class 
MemoryAgentService: from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, ) - + # 构建状态对象 state = { "data": query, "memory_config": memory_config } - + # 直接调用 summary_llm 函数 answer = await summary_llm( state=state, @@ -797,21 +800,20 @@ class MemoryAgentService: response_model=RetrieveSummaryResponse, search_mode="1" ) - + logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") return answer if answer else "信息不足,无法回答。" - + except Exception as e: logger.error(f"生成摘要失败: {str(e)}", exc_info=True) return "信息不足,无法回答。" - async def get_knowledge_type_stats( - self, - end_user_id: Optional[str] = None, - only_active: bool = True, - current_workspace_id: Optional[uuid.UUID] = None, - db: Session = None + self, + db: Session, + end_user_id: Optional[str] = None, + only_active: bool = True, + current_workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """ 统计知识库类型分布,包含: @@ -837,11 +839,6 @@ class MemoryAgentService: # 1. 统计 PostgreSQL 中的知识库类型 try: - if db is None: - from app.db import get_db - db_gen = get_db() - db = next(db_gen) - # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 @@ -881,21 +878,19 @@ class MemoryAgentService: # 3. 
计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + - result.get("Folder", 0) + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + + result.get("Folder", 0) ) return result - - async def get_interest_distribution_by_user( - self, - end_user_id: Optional[str] = None, - limit: int = 5, - language: str = "zh" + self, + end_user_id: Optional[str] = None, + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ 获取指定用户的兴趣分布标签。 @@ -921,13 +916,12 @@ class MemoryAgentService: logger.error(f"兴趣分布标签查询失败: {e}") raise Exception(f"兴趣分布标签查询失败: {e}") - async def get_user_profile( - self, - end_user_id: Optional[str] = None, - current_user_id: Optional[str] = None, - llm_id: Optional[str] = None, - db: Session = None + self, + end_user_id: Optional[str] = None, + current_user_id: Optional[str] = None, + llm_id: Optional[str] = None, + db: Session = None ) -> Dict[str, Any]: """ 获取用户详情,包含: @@ -1017,7 +1011,8 @@ class MemoryAgentService: # 定义标签提取的结构 class UserTags(BaseModel): - tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") + tags: list[str] = Field(..., + description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") messages = [ { @@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ValueError: 当终端用户不存在或应用未发布时 """ import json as json_module - import uuid from sqlalchemy import select @@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An # 3. 
兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填 memory_config_id_to_use = end_user.memory_config_id - + # 如果已有 memory_config_id,直接使用 # 如果新创建enduser,enduser.memory_config_id 必定为none # 那么使用从release中获取memory_config_id为预期行为,并且回填到 # end_user.memory_config_id if not memory_config_id_to_use: logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config") - + # 获取最新发布版本 stmt = ( select(AppRelease) @@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ) # TODO: change to current_release_id latest_release = db.scalars(stmt).first() - + if latest_release: config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") config = {} - + # 使用 MemoryConfigService 的提取方法 memory_config_service = MemoryConfigService(db) legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 验证提取的 config_id 是否存在于数据库中 from app.models.memory_config_model import MemoryConfig as MemoryConfigModel existing_config = db.get(MemoryConfigModel, legacy_config_id) - + if existing_config: memory_config_id_to_use = legacy_config_id - + # 回填到 end_user 表(lazy update) end_user.memory_config_id = memory_config_id_to_use db.commit() @@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An "workspace_id": str(app.workspace_id) } - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") + logger.info( + f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") return result @@ -1312,7 +1307,7 @@ def 
get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - + # 创建映射 - 保留 EndUser 对象引用以便回填 end_user_map = {str(eu.id): eu for eu in end_users} user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users} @@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取 users_needing_migration = [ - (end_user_id, data["app_id"]) - for end_user_id, data in user_data.items() + (end_user_id, data["app_id"]) + for end_user_id, data in user_data.items() if not data["memory_config_id"] ] - + if users_needing_migration: # 批量获取相关应用的最新发布版本 migration_app_ids = list(set(app_id for _, app_id in users_needing_migration)) - + # 查询每个应用的最新活跃发布版本 app_latest_releases = {} for app_id in migration_app_ids: @@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) latest_release = db.scalars(stmt).first() if latest_release: app_latest_releases[app_id] = latest_release - + # 为每个需要迁移的用户提取 memory_config_id config_service = MemoryConfigService(db) users_to_backfill = [] # [(end_user, memory_config_id), ...] 
- + for end_user_id, app_id in users_needing_migration: latest_release = app_latest_releases.get(app_id) if not latest_release: continue - + config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") continue - + # 使用 MemoryConfigService 的提取方法 app = app_map.get(app_id) if not app: continue - + legacy_config_id, is_legacy_int = config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 更新 user_data 中的 memory_config_id user_data[end_user_id]["memory_config_id"] = legacy_config_id - + # 记录需要回填的用户(稍后验证配置存在后再回填) end_user = end_user_map.get(end_user_id) if end_user: @@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) logger.info( f"Legacy int config detected for end_user {end_user_id}, will use workspace default" ) - + # 验证提取的 config_id 是否存在于数据库中 if users_to_backfill: config_ids_to_validate = list(set(cid for _, cid in users_to_backfill)) @@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) MemoryConfig.config_id.in_(config_ids_to_validate) ).all() valid_config_ids = {mc.config_id for mc in existing_configs} - + # 只回填存在的配置 valid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid in valid_config_ids ] invalid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid not in valid_config_ids ] - + if invalid_backfills: invalid_ids = [str(cid) for _, cid in invalid_backfills] logger.warning( @@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 清除 user_data 中无效的 config_id for eu, cid in invalid_backfills: 
user_data[str(eu.id)]["memory_config_id"] = None - + # 批量回填 end_user.memory_config_id if valid_backfills: for end_user, memory_config_id in valid_backfills: @@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id direct_config_ids = [] workspace_fallback_users = [] # [(end_user_id, workspace_id), ...] - + for end_user_id, data in user_data.items(): if data["memory_config_id"]: direct_config_ids.append(data["memory_config_id"]) @@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑) workspace_default_configs = {} unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users)) - + if unique_workspace_ids: config_service = MemoryConfigService(db) for workspace_id in unique_workspace_ids: @@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 7. 
构建最终结果 for end_user_id, data in user_data.items(): memory_config = None - + # 优先使用 end_user 直接分配的配置 if data["memory_config_id"]: memory_config = config_id_to_config.get(data["memory_config_id"]) - + # 回退到工作空间默认配置 if not memory_config: workspace_id = app_to_workspace.get(data["app_id"]) @@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} logger.info(f"Successfully retrieved {len(result)} connected configs") - return result \ No newline at end of file + return result diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a8c39a5a..f86fbed8 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -140,9 +140,11 @@ class MemoryAPIService: try: # Delegate to MemoryAgentService + # Convert string message to list[dict] format expected by MemoryAgentService + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, - messages=message, + messages=messages, config_id=config_id, db=self.db, storage_type=storage_type, @@ -151,9 +153,18 @@ class MemoryAPIService: logger.info(f"Memory write successful for end_user: {end_user_id}") + # result may be a string "success" or a dict with a "status" key + # Preserve the full dict so callers don't silently lose extra fields + # (e.g. error codes, metadata) returned by MemoryAgentService. 
+ if isinstance(result, dict): + return { + **result, + "status": result.get("status", "unknown"), + "end_user_id": end_user_id, + } return { - "status": "success" if result == "success" else result, - "end_user_id": end_user_id + "status": result if isinstance(result, str) else "success", + "end_user_id": end_user_id, } except ConfigurationError as e: diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 8d6071cc..05aed57e 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -390,19 +390,59 @@ def get_rag_total_kb( current_user: User ) -> int: """ - 根据当前用户所在的workspace_id查询konwledges表所有不同id的数量 + 根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量 """ workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}") + business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}") try: - total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id) + total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id) business_logger.info(f"成功获取RAG总知识库数: {total_kb}") return total_kb except Exception as e: business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}") raise + +def get_rag_user_kb_total_chunk( + db: Session, + current_user: User +) -> int: + """ + 根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。 + 与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。 + """ + workspace_id = current_user.current_workspace_id + business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}") + + try: + from app.models.document_model import Document + from app.models.end_user_model import EndUser + from app.models.app_model import App + from sqlalchemy import func + + # 通过 App 关联取该 workspace 下所有 end_user_id 
+ end_user_ids = [ + str(eid) for (eid,) in db.query(EndUser.id) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ] + if not end_user_ids: + return 0 + + file_names = [f"{uid}.txt" for uid in end_user_ids] + result = db.query(func.sum(Document.chunk_num)).filter( + Document.file_name.in_(file_names) + ).scalar() + + total_chunk = int(result or 0) + business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}") + return total_chunk + except Exception as e: + business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}") + raise + def get_current_user_total_chunk( end_user_id: str, db: Session, diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index 420f7ca1..b8961d33 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -1,45 +1,42 @@ # 修改 memory_konwledges_server.py 文件 -import asyncio import os -import re import uuid from pathlib import Path from typing import Optional -from pydantic import BaseModel, Field +from fastapi import HTTPException, status +from pydantic import BaseModel +from sqlalchemy.orm import Session +from app.celery_app import celery_app +from app.core.config import settings +from app.core.logging_config import get_api_logger from app.core.rag.models.chunk import DocumentChunk from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.response_utils import success -from app.db import get_db -from app.schemas import file_schema, document_schema -from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query +from app.db import get_db_context from app.models.document_model import Document -import uuid -from sqlalchemy.orm import Session -from fastapi import HTTPException, status - -from app.core.config import settings from app.models.user_model import User +from app.schemas import file_schema, document_schema 
from app.schemas.file_schema import CustomTextFileCreate from app.services import document_service, file_service, knowledge_service -from app.celery_app import celery_app -from app.core.logging_config import get_api_logger -from app.schemas.file_schema import CustomTextFileCreate -from app.db import get_db + # 创建一个简单的用户类用于测试 api_logger = get_api_logger() + class ChunkCreate(BaseModel): content: str + + class SimpleUser: def __init__(self, user_id: str): # 确保ID是UUID类型 self.id = user_id self.username = user_id -'''解析''' + async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): """ 解析指定文档 @@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") raise -'''获取块ID''' + async def get_document_chunks( kb_id: uuid.UUID, document_id: uuid.UUID, @@ -198,7 +195,7 @@ async def get_document_chunks( return success(data=result, msg="文档块列表查询成功") -'''查找文档ID''' + def find_document_id_by_kb_and_filename( db: Session, kb_id: str, @@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename( except Exception as e: return None -'''获取知识库ID''' + def find_documents_by_kb_id( db: Session, kb_id: str, @@ -268,18 +265,14 @@ def find_documents_by_kb_id( except Exception as e: return [] -''''上传文件''' + async def memory_konwledges_up( kb_id: str, parent_id: str, create_data: file_schema.CustomTextFileCreate, - db: Session = Depends(get_db), - current_user: SimpleUser = None, # 修改为SimpleUser + db: Session, + current_user: SimpleUser, ): - # 如果没有提供current_user,则创建一个默认的 - if current_user is None: - current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60") - content_bytes = create_data.content.encode('utf-8') file_size = len(content_bytes) print(f"file size: {file_size} byte") @@ -350,8 +343,6 @@ async def memory_konwledges_up( return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") -'''添加新块''' - 
async def create_document_chunk( kb_id: uuid.UUID, @@ -417,7 +408,7 @@ async def create_document_chunk( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询文档块失败: {error_msg}" ) - + sort_id = sort_id + 1 # 5. 创建文档块 @@ -450,6 +441,7 @@ async def create_document_chunk( return success(data=chunk, msg="文档块创建成功") + async def write_rag(end_user_id, message, user_rag_memory_id): """ 将消息写入 RAG 知识库 @@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id): detail=f"知识库ID格式无效: {user_rag_memory_id}" ) - db_gen = get_db() - db = next(db_gen) - - try: + with get_db_context() as db: create_data = CustomTextFileCreate(title=end_user_id, content=message) current_user = SimpleUser(user_rag_memory_id) # 检查文档是否已存在 document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") - print('======',document) + print('======', document) api_logger.info(f"查找文档结果: document_id={document}") if document is not None: # 文档已存在,直接添加新块 @@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") return result - finally: - # 确保数据库会话被关闭 - db.close() \ No newline at end of file diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 1083f750..02fd1051 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Create --- def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + # 业务层检查同一工作空间下是否已存在同名配置 + if params.workspace_id and params.config_name: + from app.models.memory_config_model import MemoryConfig + existing = ( + self.db.query(MemoryConfig) + .filter_by(workspace_id=params.workspace_id, config_name=params.config_name) + .first() + ) + if existing: + raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}") + # 
如果workspace_id存在且模型字段未全部指定,则自动获取 if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]): configs = self._get_workspace_configs(params.workspace_id) @@ -211,6 +222,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "apply_id": config.apply_id, "scene_id": str(config.scene_id) if config.scene_id else None, "scene_name": scene_name, # 新增:场景名称 + "is_system_default": config.is_default, # 是否为系统默认配置 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index aa8cfbac..cba25f32 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -90,7 +90,8 @@ class ModelConfigService: api_key: str, api_base: Optional[str] = None, model_type: str = "llm", - test_message: str = "Hello" + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -102,6 +103,7 @@ class ModelConfigService: api_base: API基础URL model_type: 模型类型 (llm/chat/embedding/rerank) test_message: 测试消息 + is_omni: 是否为Omni模型 Returns: Dict: 验证结果 @@ -119,6 +121,7 @@ class ModelConfigService: provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni, temperature=0.7, max_tokens=100 ) @@ -257,8 +260,9 @@ class ModelConfigService: provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, - model_type=model_data.type, # 传递模型类型 - test_message="Hello" + model_type=model_data.type, + test_message="Hello", + is_omni=model_data.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -279,6 +283,9 @@ class ModelConfigService: for api_key_data in api_key_datas: api_key_data.model_name = model_data.name api_key_data.provider = model_data.provider + # 同步capability和is_omni + api_key_data.capability = model_data.capability + api_key_data.is_omni = model_data.is_omni api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], 
**api_key_data.model_dump() @@ -473,6 +480,9 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: continue + + data.is_omni = model_config.is_omni + data.capability = model_config.capability # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name @@ -497,6 +507,8 @@ class ModelApiKeyService: existing_key.config = data.config existing_key.priority = data.priority existing_key.model_name = model_name + existing_key.capability = data.capability + existing_key.is_omni = data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -513,7 +525,8 @@ class ModelApiKeyService: api_key=data.api_key, api_base=data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=data.is_omni ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 @@ -528,6 +541,8 @@ class ModelApiKeyService: provider=data.provider, api_key=data.api_key, api_base=data.api_base, + capability=data.capability, + is_omni=data.is_omni, config=data.config, is_active=data.is_active, priority=data.priority @@ -550,6 +565,10 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + if api_key_data.is_omni is None: + api_key_data.is_omni = model_config.is_omni + if api_key_data.capability is None: + api_key_data.capability = model_config.capability # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( @@ -572,6 +591,8 @@ class ModelApiKeyService: existing_key.config = api_key_data.config existing_key.priority = api_key_data.priority existing_key.model_name = api_key_data.model_name + existing_key.capability = api_key_data.capability + existing_key.is_omni = api_key_data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -589,7 +610,8 @@ class 
ModelApiKeyService: api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=api_key_data.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -620,7 +642,8 @@ class ModelApiKeyService: api_key=api_key_data.api_key or existing_api_key.api_key, api_base=api_key_data.api_base or existing_api_key.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=model_config.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -755,6 +778,8 @@ class ModelBaseService: "type": model_base.type, "logo": model_base.logo, "description": model_base.description, + "capability": model_base.capability, + "is_omni": model_base.is_omni, "is_composite": False } model_config = ModelConfigRepository.create(db, model_config_data) diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d1aa46d1..f42ee95a 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -123,11 +123,14 @@ class MultiAgentOrchestrator: user_id: 用户 ID variables: 变量参数 use_llm_routing: 是否使用 LLM 路由 + web_search: 是否启用网络搜索 + memory: 是否启用记忆功能 + storage_type: 存储类型 + user_rag_memory_id: 用户 RAG 记忆 ID Yields: SSE 格式的事件流 """ - import json start_time = time.time() @@ -200,7 +203,8 @@ class MultiAgentOrchestrator: except Exception as e: logger.error( "多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self._normalized_mode} + extra={"error": str(e), "mode": self._normalized_mode}, + exc_info=True ) # 发送错误事件 yield self._format_sse_event("error", { @@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator: Yields: SSE 格式的事件流 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1278,7 +1282,7 @@ class 
MultiAgentOrchestrator: ) # 流式执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) async for event in draft_service.run_stream( agent_config=agent_config, model_config=model_config, @@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator: Returns: 执行结果 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator: ) # 执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) result = await draft_service.run( agent_config=agent_config, model_config=model_config, @@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator: self.memory = config_data.get("memory") self.variables = config_data.get("variables", []) self.tools = config_data.get("tools", {}) + self.skills = config_data.get("skills", {}) self.default_model_config_id = release.default_model_config_id return AgentConfigProxy(release, app, config_data) @@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, # 整合任务使用中等温度 max_tokens=2000 ) @@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, max_tokens=2000, extra_params={"streaming": True} # 启用流式输出 diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index c52814ed..751099d5 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -267,7 +267,7 @@ class MultiAgentService: # 2. 
验证模型配置(如果提供了) if data.default_model_config_id: - model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(data.default_model_config_id)) diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index bfb23a56..9b06c287 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,47 +9,100 @@ - OpenAI: 支持 URL 和 base64 格式 """ import uuid -from typing import List, Dict, Any, Optional, Protocol +import httpx +import base64 +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod from sqlalchemy.orm import Session +from docx import Document +import io +import PyPDF2 from app.core.logging_config import get_business_logger from app.core.exceptions import BusinessException from app.core.error_codes import BizCode from app.schemas.app_schema import FileInput, FileType, TransferMethod -from app.models.generic_file_model import GenericFile +from app.models.file_metadata_model import FileMetadata +from app.core.config import settings +from app.services.audio_transcription_service import AudioTranscriptionService logger = get_business_logger() -class ImageFormatStrategy(Protocol): - """图片格式策略接口""" +class MultimodalFormatStrategy(ABC): + """多模态格式策略基类""" + + @abstractmethod + async def format_image(self, url: str) -> Dict[str, Any]: + """格式化图片""" + pass + + @abstractmethod + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """格式化文档""" + pass + + @abstractmethod + async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]: + """格式化音频""" + pass + + @abstractmethod + async def format_video(self, url: str) -> Dict[str, Any]: + """格式化视频""" + pass + + +class DashScopeFormatStrategy(MultimodalFormatStrategy): + """通义千问策略""" async def 
format_image(self, url: str) -> Dict[str, Any]: - """将图片 URL 转换为特定 provider 的格式""" - ... - - -class DashScopeImageStrategy: - """通义千问图片格式策略""" - - async def format_image(self, url: str) -> Dict[str, Any]: - """通义千问格式: {"type": "image", "image": "url"}""" + """通义千问图片格式:{"type": "image", "image": "url"}""" return { "type": "image", "image": url } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """通义千问文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } -class BedrockImageStrategy: - """Bedrock/Anthropic 图片格式策略""" + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + 通义千问音频格式 + - 原生支持: qwen-audio 系列 + - 其他模型: 需要转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + # 通义千问音频格式:{"type": "audio", "audio": "url"} + return { + "type": "audio", + "audio": url + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """通义千问视频格式(qwen-vl 系列原生支持)""" + return { + "type": "video", + "video": url + } + + +class BedrockFormatStrategy(MultimodalFormatStrategy): + """Bedrock/Anthropic 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} """ - import httpx - import base64 from mimetypes import guess_type logger.info(f"下载并编码图片: {url}") @@ -84,9 +137,46 @@ class BedrockImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """Bedrock/Anthropic 文档格式(需要 base64 编码)""" + # Bedrock 文档需要 base64 编码 + text_bytes = text.encode('utf-8') + base64_text = base64.b64encode(text_bytes).decode('utf-8') -class OpenAIImageStrategy: - """OpenAI 图片格式策略""" + return { + "type": "document", + "source": { + "type": "base64", + "media_type": "text/plain", + "data": base64_text + } + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> 
Dict[str, Any]: + """ + Bedrock/Anthropic 音频格式 + 不支持原生音频,必须转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"[音频转录]\n{transcription}" + } + return { + "type": "text", + "text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """Bedrock/Anthropic 视频格式""" + return { + "type": "text", + "text": f"" + } + + +class OpenAIFormatStrategy(MultimodalFormatStrategy): + """OpenAI 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" @@ -97,29 +187,97 @@ class OpenAIImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """OpenAI 文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + OpenAI 音频格式 + - gpt-4o-audio 系列支持原生音频(需要 base64 编码) + - 其他模型使用转录文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + + # OpenAI 音频需要 base64 编码 + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + audio_data = response.content + base64_audio = base64.b64encode(audio_data).decode('utf-8') + # 1. 优先从 file_type (MIME) 取扩展名 + file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None + # 2. 从响应头 content-type 取 + if not file_ext: + ct = response.headers.get("content-type", "") + file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None + # 3. 从 URL 路径取扩展名 + if not file_ext: + file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None + # 4. 
默认 wav + # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} + file_ext = "wav" if not file_ext else file_ext + + return { + "type": "input_audio", + "input_audio": { + "data": f"data:;base64,{base64_audio}", + "format": file_ext + } + } + except Exception as e: + logger.error(f"下载音频失败: {e}") + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """OpenAI 视频格式""" + return { + "type": "video_url", + "video_url": { + "url": url + } + } + # Provider 到策略的映射 PROVIDER_STRATEGIES = { - "dashscope": DashScopeImageStrategy, - "bedrock": BedrockImageStrategy, - "anthropic": BedrockImageStrategy, - "openai": OpenAIImageStrategy, + "dashscope": DashScopeFormatStrategy, + "bedrock": BedrockFormatStrategy, + "anthropic": BedrockFormatStrategy, + "openai": OpenAIFormatStrategy, } class MultimodalService: """多模态文件处理服务""" - def __init__(self, db: Session, provider: str = "dashscope"): + def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False): """ 初始化多模态服务 Args: db: 数据库会话 - provider: 模型提供商(dashscope, bedrock, anthropic 等) + provider: 模型提供商(dashscope, bedrock, anthropic, openai 等) + api_key: API 密钥(用于音频转文本) + enable_audio_transcription: 是否启用音频转文本 + is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式) """ self.db = db self.provider = provider.lower() + self.api_key = api_key + self.enable_audio_transcription = enable_audio_transcription + self.is_omni = is_omni async def process_files( self, @@ -137,20 +295,32 @@ class MultimodalService: if not files: return [] + # 获取对应的策略 + # dashscope 的 omni 模型使用 OpenAI 兼容格式 + if self.provider == "dashscope" and self.is_omni: + strategy_class = OpenAIFormatStrategy + else: + strategy_class = PROVIDER_STRATEGIES.get(self.provider) + if not strategy_class: + logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") + strategy_class = 
DashScopeFormatStrategy + + strategy = strategy_class() + result = [] for idx, file in enumerate(files): try: if file.type == FileType.IMAGE: - content = await self._process_image(file) + content = await self._process_image(file, strategy) result.append(content) elif file.type == FileType.DOCUMENT: - content = await self._process_document(file) + content = await self._process_document(file, strategy) result.append(content) elif file.type == FileType.AUDIO: - content = await self._process_audio(file) + content = await self._process_audio(file, strategy) result.append(content) elif file.type == FileType.VIDEO: - content = await self._process_video(file) + content = await self._process_video(file, strategy) result.append(content) else: logger.warning(f"不支持的文件类型: {file.type}") @@ -172,55 +342,29 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - async def _process_image(self, file: FileInput) -> Dict[str, Any]: + async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理图片文件 Args: file: 图片文件输入 + strategy: 格式化策略 Returns: - Dict: 根据 provider 返回不同格式 - - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - - 通义千问: {"type": "image", "image": "url"} + Dict: 根据 provider 返回不同格式的图片内容 """ - url = await self.get_file_url(file) - - logger.debug(f"处理图片: {url}, provider={self.provider}") - - # 根据 provider 返回不同格式 - if self.provider in ["bedrock", "anthropic"]: - # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 - try: - logger.info(f"开始下载并编码图片: {url}") - base64_data, media_type = await self._download_and_encode_image(url) - result = { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": base64_data[:100] + "..." 
# 只记录前100个字符 - } - } - logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}") - # 返回完整数据 - result["source"]["data"] = base64_data - return result - except Exception as e: - logger.error(f"下载并编码图片失败: {e}", exc_info=True) - # 返回错误提示 - return { - "type": "text", - "text": f"[图片加载失败: {str(e)}]" - } - else: - # 通义千问等其他格式支持 URL + try: + url = await self.get_file_url(file) + return await strategy.format_image(url) + except Exception as e: + logger.error(f"处理图片失败: {e}", exc_info=True) return { - "type": "image", - "image": url + "type": "text", + "text": f"[图片处理失败: {str(e)}]" } - async def _download_and_encode_image(self, url: str) -> tuple[str, str]: + @staticmethod + async def _download_and_encode_image(url: str) -> tuple[str, str]: """ 下载图片并转换为 base64 @@ -230,8 +374,6 @@ class MultimodalService: Returns: tuple: (base64_data, media_type) """ - import httpx - import base64 from mimetypes import guess_type # 下载图片 @@ -258,15 +400,16 @@ class MultimodalService: return base64_data, media_type - async def _process_document(self, file: FileInput) -> Dict[str, Any]: + async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) Args: file: 文档文件输入 + strategy: 格式化策略 Returns: - Dict: text 格式的内容(包含提取的文本) + Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: # 远程文档暂不支持提取 @@ -277,48 +420,68 @@ class MultimodalService: else: # 本地文件,提取文本内容 text = await self._extract_document_text(file.upload_file_id) - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file.upload_file_id + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file.upload_file_id ).first() - file_name = generic_file.file_name if generic_file else "unknown" + file_name = file_metadata.file_name if file_metadata else "unknown" - return { - "type": "text", - "text": f"\n{text}\n" - } + # 使用策略格式化文档 + return await strategy.format_document(file_name, text) - async def 
_process_audio(self, file: FileInput) -> Dict[str, Any]: + async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理音频文件 Args: file: 音频文件输入 + strategy: 格式化策略 Returns: - Dict: 音频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的音频内容 """ - # TODO: 实现音频转文字功能 - return { - "type": "text", - "text": "[音频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) - async def _process_video(self, file: FileInput) -> Dict[str, Any]: + # 如果启用音频转文本且有 API Key + transcription = None + if self.enable_audio_transcription and self.api_key: + logger.info(f"开始音频转文本: {url}") + if self.provider == "dashscope": + transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) + elif self.provider == "openai": + transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) + else: + logger.warning(f"Provider {self.provider} 不支持音频转文本") + + return await strategy.format_audio(file.file_type, url, transcription) + except Exception as e: + logger.error(f"处理音频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理视频文件 Args: file: 视频文件输入 + strategy: 格式化策略 Returns: - Dict: 视频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的视频内容 """ - # TODO: 实现视频处理功能 - return { - "type": "text", - "text": "[视频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) + return await strategy.format_video(url) + except Exception as e: + logger.error(f"处理视频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[视频处理失败: {str(e)}]" + } async def get_file_url(self, file: FileInput) -> str: """ @@ -336,26 +499,22 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return file.url else: - # 本地文件,通过 file_storage 系统获取永久访问 URL - from app.models.file_metadata_model import FileMetadata - from app.core.config import settings - file_id = file.upload_file_id print("="*50) print("file_id",file_id) - + # 查询 
FileMetadata file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file_id, FileMetadata.status == "completed" ).first() - + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - + # 返回永久URL server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" @@ -370,58 +529,79 @@ class MultimodalService: Returns: str: 提取的文本内容 """ - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file_id, - GenericFile.status == "active" + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file_id, + FileMetadata.status == "completed" ).first() - if not generic_file: + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - # TODO: 根据文件类型提取文本 - # - PDF: 使用 PyPDF2 或 pdfplumber - # - Word: 使用 python-docx - # - TXT/MD: 直接读取 - - file_ext = generic_file.file_ext.lower() + file_ext = file_metadata.file_ext.lower() + server_url = settings.FILE_LOCAL_SERVER_URL + file_url = f"{server_url}/storage/permanent/{file_id}" if file_ext in ['.txt', '.md', '.markdown']: - return await self._read_text_file(generic_file.storage_path) + return await self._read_text_file(file_url) elif file_ext == '.pdf': - return await self._extract_pdf_text(generic_file.storage_path) + return await self._extract_pdf_text(file_url) elif file_ext in ['.doc', '.docx']: - return await self._extract_word_text(generic_file.storage_path) + return await self._extract_word_text(file_url) else: return f"[不支持的文档格式: {file_ext}]" - async def _read_text_file(self, storage_path: str) -> str: + @staticmethod + async def _read_text_file(file_url: str) -> str: """读取纯文本文件""" try: - with open(storage_path, 'r', encoding='utf-8') as f: - return f.read() + # 下载文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + return response.text except Exception as e: logger.error(f"读取文本文件失败: {e}") return 
f"[文件读取失败: {str(e)}]" - async def _extract_pdf_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_pdf_text(file_url: str) -> str: """提取 PDF 文本""" try: - # TODO: 实现 PDF 文本提取 - # import PyPDF2 或 pdfplumber - return "[PDF 文本提取功能待实现]" + # 下载 PDF 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + pdf_data = response.content + + # 使用 BytesIO 读取 PDF + text_parts = [] + pdf_file = io.BytesIO(pdf_data) + pdf_reader = PyPDF2.PdfReader(pdf_file) + for page in pdf_reader.pages: + text_parts.append(page.extract_text()) + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 PDF 文本失败: {e}") return f"[PDF 提取失败: {str(e)}]" - async def _extract_word_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_word_text(file_url: str) -> str: """提取 Word 文档文本""" try: - # TODO: 实现 Word 文本提取 - # import docx - return "[Word 文本提取功能待实现]" + # 下载 Word 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + word_data = response.content + + # 使用 BytesIO 读取 Word 文档 + word_file = io.BytesIO(word_data) + doc = Document(word_file) + text_parts = [paragraph.text for paragraph in doc.paragraphs] + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 Word 文本失败: {e}") return f"[Word 提取失败: {str(e)}]" diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 99edcc0e..184220a8 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -184,7 +184,8 @@ class PromptOptimizerService: model_name=api_config.model_name, provider=api_config.provider, api_key=api_config.api_key, - base_url=api_config.api_base + base_url=api_config.api_base, + is_omni=api_config.is_omni ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 
'prompt') diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 89d3f3d6..0d659832 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -247,6 +247,7 @@ class SharedChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -454,6 +455,7 @@ class SharedChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py index 5eb80795..0b7de6cf 100644 --- a/api/app/services/skill_service.py +++ b/api/app/services/skill_service.py @@ -121,7 +121,7 @@ class SkillService: if skill and skill.is_active: # 加载技能关联的工具 for tool_config in skill.tools: - tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool: langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 2bb96e53..f6e2ccce 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -8,6 +8,8 @@ from datetime import datetime from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.tools.mcp import MCPToolManager, SimpleMCPClient from app.repositories.tool_repository import ( ToolRepository, BuiltinToolRepository, CustomToolRepository, @@ -79,6 +81,18 @@ class ToolService: 
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) return self._config_to_info(config) if config else None + def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None): + """检查工具名称是否重复""" + query = self.db.query(ToolConfig).filter( + ToolConfig.name == name, + ToolConfig.tool_type == tool_type.value, + ToolConfig.tenant_id == tenant_id + ) + if exclude_id: + query = query.filter(ToolConfig.id != exclude_id) + if query.first(): + raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME) + def create_tool( self, name: str, @@ -92,6 +106,7 @@ class ToolService: """创建工具""" if tool_type == ToolType.BUILTIN: raise ValueError("内置工具不允许创建") + self._check_name_duplicate(name, tool_type, tenant_id) try: # 创建基础配置 @@ -141,6 +156,7 @@ class ToolService: raise ValueError("内置工具不允许修改名称、描述和图标") try: if name: + self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id) config_obj.name = name if description: config_obj.description = description @@ -209,7 +225,7 @@ class ToolService: try: # 获取工具实例 - tool = self._get_tool_instance(tool_id, tenant_id) + tool = self.get_tool_instance(tool_id, tenant_id) if not tool: return ToolResult.error_result( error=f"工具不存在: {tool_id}", @@ -335,7 +351,7 @@ class ToolService: return [] # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return [] @@ -792,7 +808,7 @@ class ToolService: """获取工具配置""" return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) - def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: + def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: """获取工具实例""" if tool_id in self._tool_cache: return self._tool_cache[tool_id] @@ -1416,7 +1432,7 @@ class ToolService: 
"""测试内置工具连接""" try: # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return {"success": False, "message": "无法创建工具实例"} diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index db5051d2..8bacc112 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping -from app.services.implicit_memory_service import ImplicitMemoryService -from app.services.memory_base_service import MemoryBaseService, MemoryTransService +from app.services.memory_base_service import MemoryBaseService from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService @@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st from app.core.language_utils import validate_language from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt - from app.db import get_db from app.repositories.end_user_repository import EndUserRepository # 验证语言参数 @@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st if end_user_id: try: # 获取数据库会话并查询用户信息 - db = next(get_db()) - try: + with get_db_context() as db: repo = EndUserRepository(db) end_user = repo.get_by_id(uuid.UUID(end_user_id)) if end_user and end_user.other_name: @@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st logger.info(f"使用 other_name 作为用户显示名称: 
{user_display_name}") else: logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") - finally: - db.close() + except Exception as e: logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 02819efb..d13e3454 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry from app.core.workflow.executor import execute_workflow, execute_workflow_stream from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config +from app.core.workflow.variable.base_variable import FileObject from app.db import get_db from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution @@ -453,11 +454,14 @@ class WorkflowService: files_struct = [] for file in files: files_struct.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } + FileObject( + type=file.type, + url=await self.multimodal_service.get_file_url(file), + transfer_method=file.transfer_method, + file_id=str(file.upload_file_id), + origin_file_type=file.file_type, + is_file=True + ).model_dump() ) return files_struct diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 2f8cdc70..e93c0c5c 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -107,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]: for workspace in workspaces: if workspace.storage_type == 'neo4j': _ensure_default_memory_config(db, workspace) + _ensure_default_ontology_scenes(db, workspace) business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}") return workspaces @@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults( ) +def 
_ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None: + """Ensure a workspace has default ontology scenes, creating them if missing. + + Checks whether any is_system_default scene exists for the workspace. + If not, runs the DefaultOntologyInitializer to create them. + + Args: + db: Database session + workspace: The workspace to check + """ + from app.models.ontology_scene import OntologyScene + + # 幂等检查:是否已存在系统默认场景 + existing = db.query(OntologyScene).filter( + OntologyScene.workspace_id == workspace.id, + OntologyScene.is_system_default.is_(True) + ).first() + + if existing: + return + + business_logger.info( + f"Workspace {workspace.id} missing default ontology scenes, creating them" + ) + + try: + initializer = DefaultOntologyInitializer(db) + success, error_msg = initializer.initialize_default_scenes( + workspace.id, language="zh" + ) + if success: + db.commit() + business_logger.info( + f"为工作空间 {workspace.id} 补建默认本体场景成功" + ) + else: + business_logger.warning( + f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}" + ) + except Exception as e: + db.rollback() + business_logger.error( + f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}" + ) + + def _create_default_memory_config( db: Session, workspace_id: uuid.UUID, diff --git a/api/app/tasks.py b/api/app/tasks.py index 299d188b..a6ebbb8e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" return result - try: + def sync_task(): trio.run( lambda: _run( row=task, @@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): with_community=with_community, ) ) + try: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(sync_task) + future.result() # Blocks until the task completes except Exception as e: progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task 
failed for task {task}:\n{str(e)}\n" progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" @@ -2108,4 +2112,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # "config_id": config_id, # "elapsed_time": elapsed_time, # "task_id": self.request.id -# } \ No newline at end of file +# } + + +# ============================================================================= +# 隐性记忆和情绪数据更新定时任务 +# ============================================================================= + +@celery_app.task( + name="app.tasks.update_implicit_emotions_storage", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=7200, # 2小时硬超时 + soft_time_limit=6900, # 1小时55分钟软超时 +) +def update_implicit_emotions_storage(self) -> Dict[str, Any]: + """定时任务:更新所有用户的隐性记忆画像和情绪建议数据 + + 遍历数据库中所有已存在数据的用户,为每个用户重新生成隐性记忆画像和情绪建议。 + 实现错误隔离,单个用户失败不影响其他用户的处理。 + + Returns: + 包含任务执行结果的字典,包括: + - status: 任务状态 (SUCCESS/FAILURE) + - message: 执行消息 + - total_users: 总用户数 + - successful_implicit: 成功更新隐性记忆的用户数 + - successful_emotion: 成功更新情绪建议的用户数 + - failed: 失败的用户数 + - user_results: 每个用户的详细结果 + - elapsed_time: 执行耗时(秒) + - task_id: 任务ID + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository + from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage + from sqlalchemy import select, func + from app.services.implicit_memory_service import ImplicitMemoryService + from app.services.emotion_analytics_service import EmotionAnalyticsService + + logger = get_logger(__name__) + logger.info("开始执行隐性记忆和情绪数据更新定时任务") + + total_users = 0 + successful_implicit = 0 + successful_emotion = 0 + failed = 0 + user_results = [] + + with get_db_context() as db: + try: + # 获取所有已存储数据的用户ID(分批次处理) + repo = ImplicitEmotionsStorageRepository(db) + + # 
先统计总数用于日志 + from sqlalchemy import func + total_users = db.execute( + select(func.count()).select_from(ImplicitEmotionsStorage) + ).scalar() or 0 + logger.info(f"找到 {total_users} 个需要更新的用户") + + # 遍历每个用户并更新数据(分批次,避免一次性加载所有ID) + for end_user_id in repo.get_all_user_ids(batch_size=100): + logger.info(f"开始处理用户: {end_user_id}") + user_start_time = time.time() + + implicit_success = False + emotion_success = False + errors = [] + + try: + # 更新隐性记忆画像 + try: + implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) + profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) + await implicit_service.save_profile_cache( + end_user_id=end_user_id, + profile_data=profile_data, + db=db + ) + implicit_success = True + logger.info(f"成功更新用户 {end_user_id} 的隐性记忆画像") + except Exception as e: + error_msg = f"隐性记忆更新失败: {str(e)}" + errors.append(error_msg) + logger.error(f"用户 {end_user_id} {error_msg}") + + # 更新情绪建议 + try: + emotion_service = EmotionAnalyticsService() + suggestions_data = await emotion_service.generate_emotion_suggestions( + end_user_id=end_user_id, + db=db, + language="zh" + ) + await emotion_service.save_suggestions_cache( + end_user_id=end_user_id, + suggestions_data=suggestions_data, + db=db + ) + emotion_success = True + logger.info(f"成功更新用户 {end_user_id} 的情绪建议") + except Exception as e: + error_msg = f"情绪建议更新失败: {str(e)}" + errors.append(error_msg) + logger.error(f"用户 {end_user_id} {error_msg}") + + # 统计结果 + if implicit_success: + successful_implicit += 1 + if emotion_success: + successful_emotion += 1 + if not implicit_success and not emotion_success: + failed += 1 + + user_elapsed = time.time() - user_start_time + + # 记录用户处理结果 + user_result = { + "end_user_id": end_user_id, + "implicit_success": implicit_success, + "emotion_success": emotion_success, + "errors": errors, + "elapsed_time": user_elapsed + } + user_results.append(user_result) + + logger.info( + f"用户 {end_user_id} 处理完成: " + f"隐性记忆={'成功' if implicit_success 
else '失败'}, " + f"情绪建议={'成功' if emotion_success else '失败'}, " + f"耗时={user_elapsed:.2f}秒" + ) + + except Exception as e: + # 单个用户失败不影响其他用户(错误隔离) + failed += 1 + user_elapsed = time.time() - user_start_time + error_info = { + "end_user_id": end_user_id, + "implicit_success": False, + "emotion_success": False, + "errors": [str(e)], + "elapsed_time": user_elapsed + } + user_results.append(error_info) + logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}") + + # ---- 处理增量用户(当天新增、尚未初始化的用户)---- + new_users_initialized = 0 + new_users_failed = 0 + logger.info("开始处理当天新增的增量用户初始化") + + for end_user_id in repo.get_new_user_ids_today(batch_size=100): + logger.info(f"开始初始化新用户: {end_user_id}") + user_start_time = time.time() + implicit_success = False + emotion_success = False + errors = [] + + try: + try: + implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) + profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) + await implicit_service.save_profile_cache( + end_user_id=end_user_id, + profile_data=profile_data, + db=db + ) + implicit_success = True + logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像") + except Exception as e: + error_msg = f"隐性记忆初始化失败: {str(e)}" + errors.append(error_msg) + logger.error(f"新用户 {end_user_id} {error_msg}") + + try: + emotion_service = EmotionAnalyticsService() + suggestions_data = await emotion_service.generate_emotion_suggestions( + end_user_id=end_user_id, + db=db, + language="zh" + ) + await emotion_service.save_suggestions_cache( + end_user_id=end_user_id, + suggestions_data=suggestions_data, + db=db + ) + emotion_success = True + logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议") + except Exception as e: + error_msg = f"情绪建议初始化失败: {str(e)}" + errors.append(error_msg) + logger.error(f"新用户 {end_user_id} {error_msg}") + + if implicit_success or emotion_success: + new_users_initialized += 1 + else: + new_users_failed += 1 + + user_elapsed = time.time() - user_start_time + user_results.append({ + 
"end_user_id": end_user_id, + "type": "init", + "implicit_success": implicit_success, + "emotion_success": emotion_success, + "errors": errors, + "elapsed_time": user_elapsed + }) + + except Exception as e: + new_users_failed += 1 + user_elapsed = time.time() - user_start_time + user_results.append({ + "end_user_id": end_user_id, + "type": "init", + "implicit_success": False, + "emotion_success": False, + "errors": [str(e)], + "elapsed_time": user_elapsed + }) + logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}") + + logger.info( + f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}" + ) + # ---- 增量用户处理结束 ---- + + # 记录总体统计信息 + logger.info( + f"隐性记忆和情绪数据更新定时任务完成: " + f"存量用户总数={total_users}, " + f"隐性记忆成功={successful_implicit}, " + f"情绪建议成功={successful_emotion}, " + f"存量失败={failed}, " + f"增量初始化成功={new_users_initialized}, " + f"增量初始化失败={new_users_failed}" + ) + + return { + "status": "SUCCESS", + "message": ( + f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;" + f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败" + ), + "total_users": total_users, + "successful_implicit": successful_implicit, + "successful_emotion": successful_emotion, + "failed": failed, + "new_users_initialized": new_users_initialized, + "new_users_failed": new_users_failed, + "user_results": user_results[:50] # 只保留前50个用户的详细结果 + } + + except Exception as e: + logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}") + return { + "status": "FAILURE", + "error": str(e), + "total_users": total_users, + "successful_implicit": successful_implicit, + "successful_emotion": successful_emotion, + "failed": failed, + "new_users_initialized": 0, + "new_users_failed": 0, + "user_results": user_results[:50] + } + + try: + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + 
asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + result["elapsed_time"] = elapsed_time + result["task_id"] = self.request.id + + return result + except Exception as e: + elapsed_time = time.time() - start_time + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": elapsed_time, + "task_id": self.request.id + } diff --git a/api/env.example b/api/env.example index 1dc4536c..bd7f3dae 100644 --- a/api/env.example +++ b/api/env.example @@ -29,10 +29,10 @@ REDIS_DB= REDIS_PASSWORD=password #celery -BROKER_URL= -RESULT_BACKEND= -CELERY_BROKER= -CELERY_BACKEND= +# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, +# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md +REDIS_DB_CELERY_BROKER= +REDIS_DB_CELERY_BACKEND= # Memory Cache Regeneration Configuration # Interval in hours for regenerating memory insight and user summary cache diff --git a/api/migrations/versions/6a4641cf192b_202603051440.py b/api/migrations/versions/6a4641cf192b_202603051440.py new file mode 100644 index 00000000..0322c9e2 --- /dev/null +++ b/api/migrations/versions/6a4641cf192b_202603051440.py @@ -0,0 +1,43 @@ +"""202603051440 + +Revision ID: 6a4641cf192b +Revises: b4af97639217 +Create Date: 2026-03-05 14:41:03.371557 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '6a4641cf192b' +down_revision: Union[str, None] = 'b4af97639217' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('implicit_emotions_storage', + sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'), + sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'), + sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('end_user_id') + ) + op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('idx_updated_at', table_name='implicit_emotions_storage') + op.drop_table('implicit_emotions_storage') + # ### end Alembic commands ### diff --git a/api/migrations/versions/b4af97639217_202603051033.py b/api/migrations/versions/b4af97639217_202603051033.py new file mode 100644 index 00000000..ddeae41c --- /dev/null +++ b/api/migrations/versions/b4af97639217_202603051033.py @@ -0,0 +1,63 @@ +"""202603051033 + +Revision ID: b4af97639217 +Revises: 4bf27c66ae63 +Create Date: 2026-03-05 10:36:06.282227 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b4af97639217' +down_revision: Union[str, None] = '4bf27c66ae63' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + # Add columns as nullable first to avoid table locks + op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + # Update existing rows with default values + op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL") + + # Now make columns NOT NULL + op.alter_column('model_api_keys', 'capability', nullable=False) + op.alter_column('model_api_keys', 'is_omni', nullable=False) + + op.alter_column('model_bases', 'capability', nullable=False) + op.alter_column('model_bases', 'is_omni', nullable=False) + + op.alter_column('model_configs', 'capability', nullable=False) + op.alter_column('model_configs', 'is_omni', nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column('model_configs', 'is_omni') + op.drop_column('model_configs', 'capability') + op.drop_column('model_bases', 'is_omni') + op.drop_column('model_bases', 'capability') + op.drop_column('model_api_keys', 'is_omni') + op.drop_column('model_api_keys', 'capability') + # ### end Alembic commands ### diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index ef7aa460..2c840c9a 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -163,9 +163,14 @@ export const getImplicitInterestAreas = (end_user_id: string) => { export const getImplicitHabits = (end_user_id: string) => { return request.get(`/memory/implicit-memory/habits/${end_user_id}`) } +// Implicit Memory - Generate user portrait export const generateProfile = (end_user_id: string) => { return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id }) } +// Implicit Memory - Check if data exists +export const implicitCheckData = (end_user_id: string) => { + return request.get(`/memory/implicit-memory/check-data/${end_user_id}`) +} // Short-term memory export const getShortTerm = (end_user_id: string) => { return request.get(`/memory/short/short_term`, { end_user_id }) diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index f6a030b4..d31746f6 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -1,16 +1,21 @@ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' -import { fileUpload } from '@/api/fileStorage' +import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' +import { request } from '@/utils/request' interface AudioRecorderProps { - onRecordingComplete?: (file: { file_id: string; file_key: string; }, blob: Blob) => void - className?: string + onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void + className?: string; + action?: string; + requestConfig?: 
Record; } const AudioRecorder: FC = ({ onRecordingComplete, className = '', + action = fileUploadUrlWithoutApiPrefix, + requestConfig = {} }) => { const [isRecording, setIsRecording] = useState(false) const recorderRef = useRef(null) @@ -33,11 +38,17 @@ const AudioRecorder: FC = ({ if (recorderRef.current) { recorderRef.current.stopRecording(() => { const blob = recorderRef.current!.getBlob() + const url = recorderRef.current!.toURL() const formData = new FormData() formData.append('file', blob, `recording_${Date.now()}.webm`) - fileUpload(formData) + request + .uploadFile(action, formData, requestConfig) .then(res => { - onRecordingComplete?.(res as { file_id: string; file_key: string; }, blob) + onRecordingComplete?.({ + ...(res as { file_id: string; file_key: string }), + type: blob.type, + url + }, blob) recorderRef.current?.destroy() recorderRef.current = null }) diff --git a/web/src/components/Chat/ChatInput.tsx b/web/src/components/Chat/ChatInput.tsx index c155bb22..49fb65d2 100644 --- a/web/src/components/Chat/ChatInput.tsx +++ b/web/src/components/Chat/ChatInput.tsx @@ -2,10 +2,11 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:14 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 12:13:52 + * @Last Modified time: 2026-03-04 18:42:49 */ import { type FC, useEffect, useMemo } from 'react' import { Flex, Input, Form } from 'antd' + import SendIcon from '@/assets/images/conversation/send.svg' import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg' import LoadingIcon from '@/assets/images/conversation/loading.svg' @@ -80,9 +81,31 @@ const ChatInput: FC = ({ ) } + if (file.type.includes('video')) { + return ( +
+
+ ) + } + if (file.type.includes('audio')) { + return ( +
+
+ ) + } return (
- {(file.type.includes('word') || file.type.includes('wordprocessingml.document')) &&
} {(file.type.includes('pdf')) &&
item.msg).join(';') } else { msg = msg || i18n.t('common.unknownError'); diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index b637e76a..846af9f7 100644 --- a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 16:35:43 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 16:35:43 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-04 18:19:24 */ /** * Server-Sent Events (SSE) Stream Utility Module @@ -176,17 +176,17 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe case 500: case 502: const errorData = await response.json(); - errorData.error || i18n.t('common.serviceUpgrading'); - message.warning(errorData.error || i18n.t('common.serviceUpgrading')); - return; + let errorInfo = errorData.error || i18n.t('common.serviceUpgrading') + message.warning(errorInfo); + throw errorInfo; case 400: const error = await response.json(); message.warning(error.error); - throw error || 'Bad Request'; + throw error.error || 'Bad Request'; case 504: const errorJson = await response.json(); message.warning(errorJson.error || i18n.t('common.serverError')); - return; + throw errorData.error; case 401: if (url?.includes('/public')) { return message.warning(i18n.t('common.publicApiCannotRefreshToken')); diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 8cb6812c..17af7613 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:27:39 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 14:21:54 + * @Last Modified time: 2026-03-05 17:03:46 */ /** * Chat debugging component for application testing @@ -13,7 +13,7 @@ import { type FC, useEffect, useState, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import clsx from 
'clsx' -import { Flex, Dropdown, type MenuProps, App } from 'antd' +import { Flex, Dropdown, type MenuProps, App, Divider } from 'antd' import ChatIcon from '@/assets/images/application/chat.png' import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png' @@ -25,7 +25,7 @@ import type { ChatItem } from '@/components/Chat/types' import { type SSEMessage } from '@/utils/stream' import ChatInput from '@/components/Chat/ChatInput' import UploadFiles from '@/views/Conversation/components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from '@/components/AudioRecorder' import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal' import type { UploadFileListModalRef } from '@/views/Conversation/types' import type { Variable } from './VariableList/types' @@ -88,7 +88,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc content: '', created_at: Date.now(), }; - + if (isCluster) { updateChatList(prev => prev.map(item => ({ ...item, @@ -134,7 +134,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update assistant message when error occurs */ - const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { + const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { if (message_length > 0 || !model_config_id) return updateChatList(prev => { @@ -171,6 +171,29 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc .then(() => { const message = msg if (!message?.trim()) return + // Validate required variables before sending + let isCanSend = true + const params: Record = {} + if (chatVariables && chatVariables.length > 0) { + const needRequired: string[] = [] + chatVariables.forEach(vo => { + params[vo.name] = vo.value + + if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { + isCanSend = false + 
needRequired.push(vo.name) + } + }) + + if (needRequired.length) { + messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) + } + } + if (!isCanSend) { + setLoading(false) + setCompareLoading(false) + return + } addUserMessage(message, fileList) setMessage(message) @@ -198,27 +221,6 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }; setTimeout(() => { - // Validate required variables before sending - let isCanSend = true - const params: Record = {} - if (chatVariables && chatVariables.length > 0) { - const needRequired: string[] = [] - chatVariables.forEach(vo => { - params[vo.name] = vo.value - - if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { - isCanSend = false - needRequired.push(vo.name) - } - }) - - if (needRequired.length) { - messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) - } - } - if (!isCanSend) { - return - } runCompare(data.app_id, { message, files: fileList.map(file => { @@ -243,7 +245,15 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc "stream": true, "timeout": 60, }, handleStreamMessage) - .finally(() => setLoading(false)); + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -288,7 +298,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update cluster message when error occurs */ - const updateClusterErrorAssistantMessage = (message_length: number) => { + const updateClusterErrorAssistantMessage = (message_length: number) => { if (message_length > 0) return updateChatList(prev => { @@ -331,7 +341,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc data.map(item => { const { conversation_id, content, message_length } = item.data as { conversation_id: string, content: string, 
message_length: number }; - switch(item.event) { + switch (item.event) { case 'start': if (conversation_id && conversationId !== conversation_id) { setConversationId(conversation_id); @@ -354,27 +364,35 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }; setTimeout(() => { - draftRun( - data.app_id, - { - message, - conversation_id: conversationId, - stream: true, - files: fileList.map(file => { - if (file.url) { - return file - } else { - return { - type: file.type, - transfer_method: 'local_file', - upload_file_id: file.response.data.file_id - } + draftRun( + data.app_id, + { + message, + conversation_id: conversationId, + stream: true, + files: fileList.map(file => { + if (file.url) { + return file + } else { + return { + type: file.type, + transfer_method: 'local_file', + upload_file_id: file.response.data.file_id } - }), - }, - handleStreamMessage - ) - .finally(() => setLoading(false)) + } + }), + }, + handleStreamMessage + ) + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -393,12 +411,17 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc const fileChange = (file?: any) => { setFileList([...fileList, file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + setFileList([...fileList, { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { - switch(key) { + switch (key) { case 'define': uploadFileListModalRef.current?.handleOpen() break @@ -415,99 +438,98 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return (
{chatList.length === 0 - ? - : <> -
- {chatList.map((chat, index) => ( -
1, - })}> - {chat.label && -
-
-
{chat.label}
-
handleDelete(index)} - >
+ : <> +
+ {chatList.map((chat, index) => ( +
1, + })}> + {chat.label && +
+
+
{chat.label}
+
handleDelete(index)} + >
+
-
- } - } - data={chat.list || []} - streamLoading={compareLoading} - labelPosition="top" - labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} - errorDesc={t('application.ReplyException')} - /> -
- ))} -
-
- - - - - ) - }, - ], - onClick: handleShowUpload + } + -
-
+ contentClassNames={{ + 'rb:max-w-[400px]!': chatList.length === 1, + 'rb:max-w-[260px]!': chatList.length === 2, + 'rb:max-w-[150px]!': chatList.length === 3, + 'rb:max-w-[108px]!': chatList.length === 4, + }} + empty={} + data={chat.list || []} + streamLoading={compareLoading} + labelPosition="top" + labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} + errorDesc={t('application.ReplyException')} + /> +
+ ))} +
+
+ + + + + ) + }, + ], + onClick: handleShowUpload + }} + > +
+
+
+ + + +
- {/* - - - */} - -
-
- + +
+ } { /** Custom file removal callback */ onRemove?: (file: UploadFile) => boolean | void | Promise; } + +const transform_file_type = { + 'text/plain': 'document/text', + 'text/markdown': 'document/markdown', + 'text/x-markdown': 'document/x-markdown', + + 'application/pdf': 'document/pdf', + + 'application/msword': 'document/doc', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx', + + 'application/vnd.ms-powerpoint': 'document/ppt', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx', +} // Mapping of file extensions to MIME types const ALL_FILE_TYPE: { [key: string]: string; } = { - // txt: 'text/plain', + txt: 'text/plain', + md: 'text/markdown', + xmd: 'text/x-markdown', + pdf: 'application/pdf', doc: 'application/msword', docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - - xls: 'application/vnd.ms-excel', - xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - csv: 'text/csv', ppt: 'application/vnd.ms-powerpoint', pptx: 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - - // md: 'text/markdown', - // htm: 'text/html', - // html: 'text/html', - // json: 'application/json', + jpg: 'image/jpeg', jpeg: 'image/jpeg', png: 'image/png', @@ -84,6 +94,23 @@ const ALL_FILE_TYPE: { bmp: 'image/bmp', webp: 'image/webp', svg: 'image/svg+xml', + + mp4: 'video/mp4', + mov: 'video/quicktime', + avi: 'video/x-msvideo', + mkv: 'video/x-matroska', + webm: 'video/webm', + flv: 'video/x-flv', + wmv: 'video/x-ms-wmv', + + mp3: 'audio/mpeg', + wav: 'audio/wav', + ogg: 'audio/ogg', + aac: 'audio/aac', + flac: 'audio/flac', + m4a: 'audio/mp4', + wma: 'audio/x-ms-wma', + xm4a: 'audio/x-m4a', } export interface UploadFilesRef { /** Current file list */ @@ -178,6 +205,10 @@ const UploadFiles = forwardRef(({ * Handles upload state changes */ const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { + 
newFileList.map(file => { + const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document' + file.type = type + }) setFileList(newFileList); if (onChange) { onChange(maxCount === 1 ? newFileList[newFileList.length - 1] : newFileList); diff --git a/web/src/views/Conversation/components/UploadFileListModal.tsx b/web/src/views/Conversation/components/UploadFileListModal.tsx index c5110701..a43b9dd4 100644 --- a/web/src/views/Conversation/components/UploadFileListModal.tsx +++ b/web/src/views/Conversation/components/UploadFileListModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:47 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 10:17:54 + * @Last Modified time: 2026-03-04 17:47:09 */ /** * Upload File List Modal Component @@ -104,7 +104,9 @@ const UploadFileListModal = forwardRef diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 2ad2a5a4..8a67b3ae 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -14,7 +14,7 @@ import { type FC, useState, useEffect, useRef } from 'react' import { useParams, useLocation } from 'react-router-dom' import { useTranslation } from 'react-i18next' import InfiniteScroll from 'react-infinite-scroll-component'; -import { Flex, Skeleton, Form, Dropdown, type MenuProps, App } from 'antd' +import { Flex, Skeleton, Form, Dropdown, type MenuProps, App, Divider } from 'antd' import { SettingOutlined } from '@ant-design/icons' import clsx from 'clsx' import dayjs from 'dayjs' @@ -35,7 +35,7 @@ import OnlineCheckedIcon from '@/assets/images/conversation/onlineChecked.svg' import MemoryFunctionCheckedIcon from '@/assets/images/conversation/memoryFunctionChecked.svg' import { type SSEMessage } from '@/utils/stream' import UploadFiles from './components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from 
'@/components/AudioRecorder' import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import UploadFileListModal from './components/UploadFileListModal' import type { VariableConfigModalRef } from '@/views/Workflow/types' @@ -305,17 +305,27 @@ const Conversation: FC = () => { }), variables: params }, handleStreamMessage, shareToken) + .catch(() => { + setLoading(false) + setStreamLoading(false) + }) .finally(() => { setLoading(false) + setStreamLoading(false) }) } const fileChange = (file?: any) => { form.setFieldValue('files', [...(queryValues.files || []), file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + form.setFieldValue('files', [...(queryValues.files || []), { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { switch(key) { @@ -329,6 +339,7 @@ const Conversation: FC = () => { form.setFieldValue('files', [...(queryValues.files || []), ...fileList]) } const updateFileList = (fileList?: any[]) => { + console.log('fileList', fileList) form.setFieldValue('files', [...(fileList || [])]) } @@ -383,7 +394,7 @@ const Conversation: FC = () => {
} - contentClassName="rb:h-[calc(100%-180px)]" + contentClassName={!queryValues?.files?.length ? "rb:h-[calc(100%-144px)]" : "rb:h-[calc(100%-208px)]"} data={chatList} streamLoading={streamLoading} loading={loading} @@ -405,13 +416,12 @@ const Conversation: FC = () => { key: 'upload', label: ( ) }, @@ -455,10 +465,19 @@ const Conversation: FC = () => { )} - {/* - + + - */} + diff --git a/web/src/views/KnowledgeBase/components/CreateModal.tsx b/web/src/views/KnowledgeBase/components/CreateModal.tsx index 76640058..d9727d18 100644 --- a/web/src/views/KnowledgeBase/components/CreateModal.tsx +++ b/web/src/views/KnowledgeBase/components/CreateModal.tsx @@ -15,6 +15,7 @@ import { } from '@/api/knowledgeBase' import RbModal from '@/components/RbModal' import SliderInput from '@/components/SliderInput' +import { stringRegExp } from '@/utils/validator' const { TextArea } = Input; const { confirm } = Modal @@ -519,12 +520,16 @@ const CreateModal = forwardRef(({ )} - +