Merge branch 'develop' into feature/tool_yjp
@@ -226,8 +226,8 @@ REDIS_PORT=6379
 REDIS_DB=1
 
 # Celery (Using Redis as broker)
-BROKER_URL=redis://127.0.0.1:6379/0
-RESULT_BACKEND=redis://127.0.0.1:6379/0
+REDIS_DB_CELERY_BROKER=1
+REDIS_DB_CELERY_BACKEND=2
 
 # JWT Secret Key (generate with: openssl rand -hex 32)
 SECRET_KEY=your-secret-key-here
@@ -201,8 +201,8 @@ REDIS_PORT=6379
 REDIS_DB=1
 
 # Celery (使用Redis作为broker)
-BROKER_URL=redis://127.0.0.1:6379/0
-RESULT_BACKEND=redis://127.0.0.1:6379/0
+REDIS_DB_CELERY_BROKER=1
+REDIS_DB_CELERY_BACKEND=2
 
 # JWT密钥 (生成方式: openssl rand -hex 32)
 SECRET_KEY=your-secret-key-here
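For orientation (not part of the commit): the new keys line up with fields read from `settings` in the Celery hunk below. A minimal pydantic-settings sketch of that mapping — the field names come from this diff, while the settings class itself is an assumption (the repo's app/core/config.py is not shown here):

    from pydantic_settings import BaseSettings, SettingsConfigDict


    class Settings(BaseSettings):
        # Assumed shape of app.core.config.settings; only fields visible
        # in this diff are listed.
        model_config = SettingsConfigDict(env_file=".env")

        REDIS_HOST: str = "127.0.0.1"
        REDIS_PORT: int = 6379
        REDIS_PASSWORD: str = ""
        REDIS_DB: int = 1
        REDIS_DB_CELERY_BROKER: int = 1   # replaces BROKER_URL
        REDIS_DB_CELERY_BACKEND: int = 2  # replaces RESULT_BACKEND
        SECRET_KEY: str = "your-secret-key-here"


    settings = Settings()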
api/app/cache/__init__.py (vendored, 4 changes)
@@ -3,10 +3,8 @@ Cache 缓存模块
 
 提供各种缓存功能的统一入口
 """
-from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache
+from .memory import InterestMemoryCache
 
 __all__ = [
-    "EmotionMemoryCache",
-    "ImplicitMemoryCache",
     "InterestMemoryCache",
 ]
api/app/cache/memory/__init__.py (vendored, 4 changes)
@@ -3,12 +3,8 @@ Memory 缓存模块
 
 提供记忆系统相关的缓存功能
 """
-from .emotion_memory import EmotionMemoryCache
-from .implicit_memory import ImplicitMemoryCache
 from .interest_memory import InterestMemoryCache
 
 __all__ = [
-    "EmotionMemoryCache",
-    "ImplicitMemoryCache",
     "InterestMemoryCache",
 ]
api/app/cache/memory/emotion_memory.py (vendored, 134 deletions)
@@ -1,134 +0,0 @@
-"""
-Emotion Suggestions Cache
-
-情绪个性化建议缓存模块
-用于缓存用户的情绪个性化建议数据
-"""
-import json
-import logging
-from typing import Optional, Dict, Any
-from datetime import datetime
-
-from app.aioRedis import aio_redis
-
-logger = logging.getLogger(__name__)
-
-
-class EmotionMemoryCache:
-    """情绪建议缓存类"""
-
-    # Key 前缀
-    PREFIX = "cache:memory:emotion_memory"
-
-    @classmethod
-    def _get_key(cls, *parts: str) -> str:
-        """生成 Redis key
-
-        Args:
-            *parts: key 的各个部分
-
-        Returns:
-            完整的 Redis key
-        """
-        return ":".join([cls.PREFIX] + list(parts))
-
-    @classmethod
-    async def set_emotion_suggestions(
-        cls,
-        user_id: str,
-        suggestions_data: Dict[str, Any],
-        expire: int = 86400
-    ) -> bool:
-        """设置用户情绪建议缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-            suggestions_data: 建议数据字典,包含:
-                - health_summary: 健康状态摘要
-                - suggestions: 建议列表
-                - generated_at: 生成时间(可选)
-            expire: 过期时间(秒),默认24小时(86400秒)
-
-        Returns:
-            是否设置成功
-        """
-        try:
-            key = cls._get_key("suggestions", user_id)
-
-            # 添加生成时间戳
-            if "generated_at" not in suggestions_data:
-                suggestions_data["generated_at"] = datetime.now().isoformat()
-
-            # 添加缓存标记
-            suggestions_data["cached"] = True
-
-            value = json.dumps(suggestions_data, ensure_ascii=False)
-            await aio_redis.set(key, value, ex=expire)
-            logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
-            return True
-        except Exception as e:
-            logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
-            return False
-
-    @classmethod
-    async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
-        """获取用户情绪建议缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            建议数据字典,如果不存在或已过期返回 None
-        """
-        try:
-            key = cls._get_key("suggestions", user_id)
-            value = await aio_redis.get(key)
-
-            if value:
-                data = json.loads(value)
-                logger.info(f"成功获取情绪建议缓存: {key}")
-                return data
-
-            logger.info(f"情绪建议缓存不存在或已过期: {key}")
-            return None
-        except Exception as e:
-            logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
-            return None
-
-    @classmethod
-    async def delete_emotion_suggestions(cls, user_id: str) -> bool:
-        """删除用户情绪建议缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            是否删除成功
-        """
-        try:
-            key = cls._get_key("suggestions", user_id)
-            result = await aio_redis.delete(key)
-            logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
-            return result > 0
-        except Exception as e:
-            logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
-            return False
-
-    @classmethod
-    async def get_suggestions_ttl(cls, user_id: str) -> int:
-        """获取情绪建议缓存的剩余过期时间
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            剩余秒数,-1表示永不过期,-2表示key不存在
-        """
-        try:
-            key = cls._get_key("suggestions", user_id)
-            ttl = await aio_redis.ttl(key)
-            logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
-            return ttl
-        except Exception as e:
-            logger.error(f"获取情绪建议缓存TTL失败: {e}")
-            return -2
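For reference, the removed class was a thin JSON-over-Redis wrapper; a sketch of how a call site presumably used it before this commit (call sites are not shown in this diff, so the flow below is an assumption):

    async def fetch_suggestions(user_id: str) -> dict:
        # Hypothetical pre-commit flow: read-through cache with a 24h TTL.
        cached = await EmotionMemoryCache.get_emotion_suggestions(user_id)
        if cached is not None:
            return cached  # carries the "cached": True marker set on write
        data = {"health_summary": "...", "suggestions": []}  # from the LLM service
        await EmotionMemoryCache.set_emotion_suggestions(user_id, data, expire=86400)
        return data

This commit drops the Redis layer in favor of database-backed storage (see the emotion controller changes further down).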
api/app/cache/memory/implicit_memory.py (vendored, 136 deletions)
@@ -1,136 +0,0 @@
-"""
-Implicit Memory Profile Cache
-
-隐式记忆用户画像缓存模块
-用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
-"""
-import json
-import logging
-from typing import Optional, Dict, Any
-from datetime import datetime
-
-from app.aioRedis import aio_redis
-
-logger = logging.getLogger(__name__)
-
-
-class ImplicitMemoryCache:
-    """隐式记忆用户画像缓存类"""
-
-    # Key 前缀
-    PREFIX = "cache:memory:implicit_memory"
-
-    @classmethod
-    def _get_key(cls, *parts: str) -> str:
-        """生成 Redis key
-
-        Args:
-            *parts: key 的各个部分
-
-        Returns:
-            完整的 Redis key
-        """
-        return ":".join([cls.PREFIX] + list(parts))
-
-    @classmethod
-    async def set_user_profile(
-        cls,
-        user_id: str,
-        profile_data: Dict[str, Any],
-        expire: int = 86400
-    ) -> bool:
-        """设置用户完整画像缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-            profile_data: 画像数据字典,包含:
-                - preferences: 偏好标签列表
-                - portrait: 四维画像对象
-                - interest_areas: 兴趣领域分布对象
-                - habits: 行为习惯列表
-                - generated_at: 生成时间(可选)
-            expire: 过期时间(秒),默认24小时(86400秒)
-
-        Returns:
-            是否设置成功
-        """
-        try:
-            key = cls._get_key("profile", user_id)
-
-            # 添加生成时间戳
-            if "generated_at" not in profile_data:
-                profile_data["generated_at"] = datetime.now().isoformat()
-
-            # 添加缓存标记
-            profile_data["cached"] = True
-
-            value = json.dumps(profile_data, ensure_ascii=False)
-            await aio_redis.set(key, value, ex=expire)
-            logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
-            return True
-        except Exception as e:
-            logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
-            return False
-
-    @classmethod
-    async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
-        """获取用户完整画像缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            画像数据字典,如果不存在或已过期返回 None
-        """
-        try:
-            key = cls._get_key("profile", user_id)
-            value = await aio_redis.get(key)
-
-            if value:
-                data = json.loads(value)
-                logger.info(f"成功获取用户画像缓存: {key}")
-                return data
-
-            logger.info(f"用户画像缓存不存在或已过期: {key}")
-            return None
-        except Exception as e:
-            logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
-            return None
-
-    @classmethod
-    async def delete_user_profile(cls, user_id: str) -> bool:
-        """删除用户完整画像缓存
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            是否删除成功
-        """
-        try:
-            key = cls._get_key("profile", user_id)
-            result = await aio_redis.delete(key)
-            logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
-            return result > 0
-        except Exception as e:
-            logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
-            return False
-
-    @classmethod
-    async def get_profile_ttl(cls, user_id: str) -> int:
-        """获取用户画像缓存的剩余过期时间
-
-        Args:
-            user_id: 用户ID(end_user_id)
-
-        Returns:
-            剩余秒数,-1表示永不过期,-2表示key不存在
-        """
-        try:
-            key = cls._get_key("profile", user_id)
-            ttl = await aio_redis.ttl(key)
-            logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
-            return ttl
-        except Exception as e:
-            logger.error(f"获取用户画像缓存TTL失败: {e}")
-            return -2
@@ -1,26 +1,54 @@
 import os
 import platform
 from datetime import timedelta
 from celery.schedules import crontab
 from urllib.parse import quote
 
 from celery import Celery
-from celery.schedules import crontab
 
 from app.core.config import settings
 from app.core.logging_config import get_logger
 
 logger = get_logger(__name__)
 
 # macOS fork() safety - must be set before any Celery initialization
 if platform.system() == 'Darwin':
     os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
 
 # 创建 Celery 应用实例
-# broker: 任务队列(使用 Redis DB 0)
-# backend: 结果存储(使用 Redis DB 10)
+# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定)
+# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定)
+# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
+# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
+
+# Build canonical broker/backend URLs and force them into os.environ so that
+# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
+# cannot be overridden by stray env vars.
+# See: https://github.com/celery/celery/issues/4284
+_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
+_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
+os.environ["CELERY_BROKER_URL"] = _broker_url
+os.environ["CELERY_RESULT_BACKEND"] = _backend_url
+# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
+# integration and accidentally override our canonical URLs.
+os.environ.pop("BROKER_URL", None)
+os.environ.pop("RESULT_BACKEND", None)
+os.environ.pop("CELERY_BROKER", None)
+os.environ.pop("CELERY_BACKEND", None)
 
 celery_app = Celery(
     "redbear_tasks",
-    broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
-    backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
+    broker=_broker_url,
+    backend=_backend_url,
 )
 
+logger.info(
+    "Celery app initialized",
+    extra={
+        "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
+        "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
+    },
+)
 # Default queue for unrouted tasks
 celery_app.conf.task_default_queue = 'memory_tasks'
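A startup assertion makes the "canonical URL" guarantee checkable (a sketch, not part of this commit; celery_app.conf.broker_url and result_backend are standard Celery settings):

    def assert_canonical_celery_urls() -> None:
        # Fail fast if a stray env var (e.g. BROKER_URL picked up by Celery's
        # Click-based CLI) overrode the URLs forced into os.environ above.
        if (celery_app.conf.broker_url != _broker_url
                or celery_app.conf.result_backend != _backend_url):
            raise RuntimeError(
                "Celery broker/backend overridden by environment; "
                "check for stray BROKER_URL / RESULT_BACKEND variables"
            )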
@@ -44,8 +72,8 @@ celery_app.conf.update(
     task_ignore_result=False,
 
     # 超时设置
-    task_time_limit=1800,  # 30分钟硬超时
-    task_soft_time_limit=1500,  # 25分钟软超时
+    task_time_limit=3600,  # 60分钟硬超时
+    task_soft_time_limit=3000,  # 50分钟软超时
 
     # Worker 设置 (per-worker settings are in docker-compose command line)
     worker_prefetch_multiplier=1,  # Don't hoard tasks, fairer distribution
@@ -84,6 +112,7 @@ celery_app.conf.update(
         'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
         'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
         'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
+        'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
     },
 )
@@ -95,6 +124,10 @@ memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=
 memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
 workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
 forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
+implicit_emotions_update_schedule = crontab(
+    hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
+    minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
+)
 
 # 构建定时任务配置
 beat_schedule_config = {
@@ -120,6 +153,11 @@ beat_schedule_config = {
         "schedule": memory_increment_schedule,
         "args": (),
     },
+    "update-implicit-emotions-storage": {
+        "task": "app.tasks.update_implicit_emotions_storage",
+        "schedule": implicit_emotions_update_schedule,
+        "args": (),
+    },
 }
 
 celery_app.conf.beat_schedule = beat_schedule_config
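Schedule semantics, for reviewers: crontab(hour=H, minute=M) fires once per day at H:M, while timedelta(hours=N) fires every N hours relative to when beat starts. A tiny sketch with illustrative values (the real values come from settings):

    from datetime import timedelta
    from celery.schedules import crontab

    daily_at_0330 = crontab(hour=3, minute=30)  # once a day, at 03:30
    every_six_hours = timedelta(hours=6)        # every 6h from beat startup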
@@ -396,10 +396,10 @@ async def draft_run(
     from app.models import AgentConfig, ModelConfig
     from sqlalchemy import select
     from app.core.exceptions import BusinessException
-    from app.services.draft_run_service import DraftRunService
+    from app.services.draft_run_service import AgentRunService
 
     service = AppService(db)
-    draft_service = DraftRunService(db)
+    draft_service = AgentRunService(db)
 
     # 1. 验证应用
     app = service._get_app_or_404(app_id)
@@ -484,8 +484,8 @@ async def draft_run(
         }
     )
 
-    from app.services.draft_run_service import DraftRunService
-    draft_service = DraftRunService(db)
+    from app.services.draft_run_service import AgentRunService
+    draft_service = AgentRunService(db)
     result = await draft_service.run(
         agent_config=agent_cfg,
         model_config=model_config,
@@ -789,8 +789,8 @@ async def draft_run_compare(
     # 流式返回
     if payload.stream:
         async def event_generator():
-            from app.services.draft_run_service import DraftRunService
-            draft_service = DraftRunService(db)
+            from app.services.draft_run_service import AgentRunService
+            draft_service = AgentRunService(db)
             async for event in draft_service.run_compare_stream(
                 agent_config=agent_cfg,
                 models=model_configs,
@@ -820,8 +820,8 @@ async def draft_run_compare(
         )
 
     # 非流式返回
-    from app.services.draft_run_service import DraftRunService
-    draft_service = DraftRunService(db)
+    from app.services.draft_run_service import AgentRunService
+    draft_service = AgentRunService(db)
     result = await draft_service.run_compare(
         agent_config=agent_cfg,
         models=model_configs,
@@ -835,7 +835,8 @@ async def draft_run_compare(
         web_search=True,
         memory=True,
         parallel=payload.parallel,
-        timeout=payload.timeout or 60
+        timeout=payload.timeout or 60,
+        files=payload.files
     )
 
     logger.info(
@@ -441,14 +441,14 @@ async def retrieve_chunks(
     # 1 participle search, 2 semantic search, 3 hybrid search
     match retrieve_data.retrieve_type:
         case chunk_schema.RetrieveType.PARTICIPLE:
-            rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
+            rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
             return success(data=rs, msg="retrieval successful")
         case chunk_schema.RetrieveType.SEMANTIC:
-            rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
+            rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
             return success(data=rs, msg="retrieval successful")
         case _:
-            rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
-            rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
+            rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
+            rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
             # Efficient deduplication
             seen_ids = set()
             unique_rs = []
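The hunk cuts off just before the merge loop; what "efficient deduplication" with seen_ids typically looks like (the loop below is an assumption — it lies outside this hunk):

    # Hypothetical continuation of the hybrid branch above.
    for item in rs1 + rs2:
        chunk_id = item.get("id")
        if chunk_id in seen_ids:
            continue  # O(1) set membership keeps the merge linear
        seen_ids.add(chunk_id)
        unique_rs.append(item)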
@@ -208,14 +208,64 @@ async def get_emotion_health(
 
 
+# @router.post("/check-data", response_model=ApiResponse)
+# async def check_emotion_data_exists(
+#     request: EmotionSuggestionsRequest,
+#     db: Session = Depends(get_db),
+#     current_user: User = Depends(get_current_user),
+# ):
+#     """检查用户情绪建议数据是否存在
+
+#     Args:
+#         request: 包含 end_user_id
+#         db: 数据库会话
+#         current_user: 当前用户
+
+#     Returns:
+#         数据存在状态
+#     """
+#     try:
+#         api_logger.info(
+#             f"检查用户情绪建议数据是否存在: {request.end_user_id}",
+#             extra={"end_user_id": request.end_user_id}
+#         )
+
+#         # 从数据库获取建议
+#         data = await emotion_service.get_cached_suggestions(
+#             end_user_id=request.end_user_id,
+#             db=db
+#         )
+
+#         if data is None:
+#             api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
+#             return fail(
+#                 BizCode.NOT_FOUND,
+#                 "情绪建议数据不存在,请点击右上角刷新进行初始化",
+#                 {"exists": False}
+#             )
+
+#         api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
+#         return success(data={"exists": True}, msg="情绪建议数据已存在")
+
+#     except Exception as e:
+#         api_logger.error(
+#             f"检查情绪建议数据失败: {str(e)}",
+#             extra={"end_user_id": request.end_user_id},
+#             exc_info=True
+#         )
+#         raise HTTPException(
+#             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+#             detail=f"检查情绪建议数据失败: {str(e)}"
+#         )
+
+
 @router.post("/suggestions", response_model=ApiResponse)
 async def get_emotion_suggestions(
     request: EmotionSuggestionsRequest,
     language_type: str = Header(default=None, alias="X-Language-Type"),
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ):
-    """获取个性化情绪建议(从缓存读取)
+    """获取个性化情绪建议(从数据库读取)
 
     Args:
         request: 包含 end_user_id 和可选的 config_id
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
         current_user: 当前用户
 
     Returns:
-        缓存的个性化情绪建议响应
+        存储的个性化情绪建议响应
     """
     try:
         # 使用集中化的语言校验
         language = get_language_from_header(language_type)
 
         api_logger.info(
-            f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
+            f"用户 {current_user.username} 请求获取个性化情绪建议",
             extra={
                 "end_user_id": request.end_user_id,
                 "config_id": request.config_id
             }
         )
 
-        # 从缓存获取建议
+        # 从数据库获取建议
        data = await emotion_service.get_cached_suggestions(
            end_user_id=request.end_user_id,
            db=db
        )
 
        if data is None:
-            # 缓存不存在或已过期,自动触发生成
            api_logger.info(
-                f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
+                f"用户 {request.end_user_id} 的建议数据不存在",
                extra={"end_user_id": request.end_user_id}
            )
-            try:
-                data = await emotion_service.generate_emotion_suggestions(
-                    end_user_id=request.end_user_id,
-                    db=db,
-                    language=language
-                )
-                # 保存到缓存
-                await emotion_service.save_suggestions_cache(
-                    end_user_id=request.end_user_id,
-                    suggestions_data=data,
-                    db=db,
-                    expires_hours=24
-                )
-            except (ValueError, KeyError) as gen_e:
-                # 预期内的业务异常:配置缺失、数据格式问题等
-                api_logger.warning(
-                    f"自动生成建议失败(业务异常): {str(gen_e)}",
-                    extra={"end_user_id": request.end_user_id}
-                )
-                return fail(
-                    BizCode.NOT_FOUND,
-                    f"自动生成建议失败: {str(gen_e)}",
-                    ""
-                )
-            except Exception as gen_e:
-                # 非预期异常:记录完整 traceback 便于排查
-                api_logger.error(
-                    f"自动生成建议时发生未预期异常: {str(gen_e)}",
-                    extra={"end_user_id": request.end_user_id},
-                    exc_info=True
-                )
-                raise HTTPException(
-                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                    detail=f"生成建议时发生内部错误: {str(gen_e)}"
-                )
+            return success(
+                data={"exists": False},
+                msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
+            )
 
        api_logger.info(
-            "个性化建议获取成功(缓存)",
+            "个性化建议获取成功",
            extra={
                "end_user_id": request.end_user_id,
                "suggestions_count": len(data.get("suggestions", []))
            }
        )
 
-        return success(data=data, msg="个性化建议获取成功(缓存)")
+        return success(data=data, msg="个性化建议获取成功")
 
    except Exception as e:
        api_logger.error(
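Net behavior change: /suggestions no longer auto-generates on a cache miss; it returns {"exists": false} and the client is expected to trigger generation explicitly. A hedged client-side sketch of the resulting two-step flow (httpx and the /suggestions/generate path are assumptions — only /suggestions appears in this hunk, the generation route's decorator is not shown):

    import httpx

    async def load_suggestions(client: httpx.AsyncClient, end_user_id: str) -> dict:
        body = {"end_user_id": end_user_id}
        resp = (await client.post("/suggestions", json=body)).json()
        if resp.get("data", {}).get("exists") is False:
            # Data never generated: call the generation endpoint (previously
            # implicit server-side), then re-read. Path is hypothetical.
            await client.post("/suggestions/generate", json=body)
            resp = (await client.post("/suggestions", json=body)).json()
        return resp["data"]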
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
     db: Session = Depends(get_db),
     current_user: User = Depends(get_current_user),
 ):
-    """生成个性化情绪建议(调用LLM并缓存)
+    """生成个性化情绪建议(调用LLM并保存到数据库)
 
     Args:
         request: 包含 end_user_id
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
             language=language
         )
 
-        # 保存到缓存
+        # 保存到数据库
         await emotion_service.save_suggestions_cache(
             end_user_id=request.end_user_id,
             suggestions_data=data,
-            db=db,
-            expires_hours=24
+            db=db
         )
 
         api_logger.info(
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"生成个性化建议失败: {str(e)}"
-        )
+        )
@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
         raise ValueError("confidence_threshold must be between 0.0 and 1.0")
 
 
+@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
+@cur_workspace_access_guard()
+async def check_user_data_exists(
+    end_user_id: str,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+) -> ApiResponse:
+    """
+    检查用户画像数据是否存在
+
+    Args:
+        end_user_id: 目标用户ID
+
+    Returns:
+        数据存在状态
+    """
+    api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
+
+    try:
+        # Validate inputs
+        validate_user_id(end_user_id)
+
+        # Create service with user-specific config
+        service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
+
+        # Get cached profile
+        cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
+
+        if cached_profile is None:
+            api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
+            return success(
+                data={"exists": False},
+                msg="画像数据不存在,请点击右上角刷新进行初始化"
+            )
+
+        api_logger.info(f"用户 {end_user_id} 的画像数据存在")
+        return success(data={"exists": True}, msg="画像数据已存在")
+
+    except Exception as e:
+        return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
+
+
 @router.get("/preferences/{end_user_id}", response_model=ApiResponse)
 @cur_workspace_access_guard()
 async def get_preference_tags(
@@ -159,12 +201,8 @@ async def get_preference_tags(
         cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
 
         if cached_profile is None:
-            api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
-            return fail(
-                BizCode.NOT_FOUND,
-                "画像缓存不存在或已过期,请右上角刷新生成新画像",
-                ""
-            )
+            api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
+            return fail(BizCode.NOT_FOUND, "", "")
 
         # Extract preferences from cache
         preferences = cached_profile.get("preferences", [])
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
         cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
 
         if cached_profile is None:
-            api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
-            return fail(
-                BizCode.NOT_FOUND,
-                "画像缓存不存在或已过期,请右上角刷新生成新画像",
-                ""
-            )
+            api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
+            return fail(BizCode.NOT_FOUND, "", "")
 
         # Extract portrait from cache
         portrait = cached_profile.get("portrait", {})
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
         cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
 
         if cached_profile is None:
-            api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
-            return fail(
-                BizCode.NOT_FOUND,
-                "画像缓存不存在或已过期,请右上角刷新生成新画像",
-                ""
-            )
+            api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
+            return fail(BizCode.NOT_FOUND, "", "")
 
         # Extract interest areas from cache
         interest_areas = cached_profile.get("interest_areas", {})
@@ -330,12 +360,8 @@ async def get_behavior_habits(
         cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
 
         if cached_profile is None:
-            api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
-            return fail(
-                BizCode.NOT_FOUND,
-                "画像缓存不存在或已过期,请右上角刷新生成新画像",
-                ""
-            )
+            api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
+            return fail(BizCode.NOT_FOUND, "", "")
 
         # Extract habits from cache
         habits = cached_profile.get("habits", [])
@@ -90,7 +90,7 @@ async def get_mcp_servers(
                              cookies=cookies)
         raise_for_http_status(r)
     except requests.exceptions.RequestException as e:
-        api_logger.error(f"mFailed to get MCP servers: {str(e)}")
+        api_logger.error(f"Failed to get MCP servers: {str(e)}")
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail=f"Failed to get MCP servers: {str(e)}"
@@ -118,6 +118,65 @@ async def get_mcp_servers(
     return success(data=result, msg="Query of mcp servers list successful")
 
 
+@router.get("/operational_mcp_servers", response_model=ApiResponse)
+async def get_operational_mcp_servers(
+    mcp_market_config_id: uuid.UUID,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
+):
+    """
+    Query the operational mcp servers list in pages
+    - Support keyword search for name,author,owner
+    - Return paging metadata + operational mcp server list
+    """
+    api_logger.info(
+        f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
+
+    # 1. Query mcp market config information from the database
+    api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
+    db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
+                                                                                 mcp_market_config_id=mcp_market_config_id,
+                                                                                 current_user=current_user)
+    if not db_mcp_market_config:
+        api_logger.warning(
+            f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="The mcp market config does not exist or access is denied"
+        )
+
+    # 2. Execute paged query
+    api = MCPApi()
+    token = db_mcp_market_config.token
+    api.login(token)
+
+    url = f'{api.mcp_base_url}/operational'
+    headers = api.builder_headers(api.headers)
+
+    try:
+        cookies = api.get_cookies(access_token=token, cookies_required=True)
+        r = api.session.get(url, headers=headers, cookies=cookies)
+        raise_for_http_status(r)
+    except requests.exceptions.RequestException as e:
+        api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"Failed to get operational MCP servers: {str(e)}"
+        )
+
+    data = api._handle_response(r)
+    total = data.get('total_count', 0)
+    mcp_server_list = data.get('mcp_server_list', [])
+    # items = [{
+    #     'name': item.get('name', ''),
+    #     'id': item.get('id', ''),
+    #     'description': item.get('description', '')
+    # } for item in mcp_server_list]
+
+    # 3. Return structured response
+    return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
+
+
 @router.get("/mcp_server", response_model=ApiResponse)
 async def get_mcp_server(
     mcp_market_config_id: uuid.UUID,
@@ -1,28 +1,29 @@
 from typing import List, Optional
 
+from dotenv import load_dotenv
+from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
+from sqlalchemy.orm import Session
+from starlette.responses import StreamingResponse
+
 from app.cache.memory.interest_memory import InterestMemoryCache
 from app.celery_app import celery_app
 from app.core.error_codes import BizCode
 from app.core.language_utils import get_language_from_header
 from app.core.logging_config import get_api_logger
+from app.core.memory.agent.utils.redis_tool import store
+from app.core.memory.agent.utils.session_tools import SessionService
 from app.core.rag.llm.cv_model import QWenCV
 from app.core.response_utils import fail, success
 from app.db import get_db
 from app.dependencies import cur_workspace_access_guard, get_current_user
 from app.models import ModelApiKey
 from app.models.user_model import User
-from app.core.memory.agent.utils.session_tools import SessionService
-from app.core.memory.agent.utils.redis_tool import store
-from app.repositories import knowledge_repository, WorkspaceRepository
+from app.repositories import knowledge_repository
 from app.schemas.memory_agent_schema import UserInput, Write_UserInput
 from app.schemas.response_schema import ApiResponse
 from app.services import task_service, workspace_service
 from app.services.memory_agent_service import MemoryAgentService
 from app.services.model_service import ModelConfigService
-from dotenv import load_dotenv
-from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
-from sqlalchemy.orm import Session
-from starlette.responses import StreamingResponse
 
 load_dotenv()
 api_logger = get_api_logger()
@@ -37,7 +38,7 @@ router = APIRouter(
 
 @router.get("/health/status", response_model=ApiResponse)
 async def get_health_status(
-        current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user)
 ):
     """
     Get latest health status written by Celery periodic task
@@ -55,8 +56,9 @@ async def get_health_status(
 
 @router.get("/download_log")
 async def download_log(
-        log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
-        current_user: User = Depends(get_current_user)
+    log_type: str = Query("file", regex="^(file|transmission)$",
+                          description="日志类型: file=完整文件, transmission=实时流式传输"),
+    current_user: User = Depends(get_current_user)
 ):
     """
     Download or stream agent service log file
@@ -75,16 +77,16 @@ async def download_log(
     - transmission mode: StreamingResponse with SSE
     """
     api_logger.info(f"Log download requested with log_type={log_type}")
 
     # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
     if log_type not in ["file", "transmission"]:
         api_logger.warning(f"Invalid log_type parameter: {log_type}")
         return fail(
-           BizCode.BAD_REQUEST,
-           "无效的log_type参数",
+            BizCode.BAD_REQUEST,
+            "无效的log_type参数",
             "log_type必须是'file'或'transmission'"
         )
 
     # Route to appropriate mode
     if log_type == "file":
         # File mode: Return complete log file content
@@ -119,10 +121,10 @@ async def download_log(
 @router.post("/writer_service", response_model=ApiResponse)
 @cur_workspace_access_guard()
 async def write_server(
-        user_input: Write_UserInput,
-        language_type: str = Header(default=None, alias="X-Language-Type"),
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    user_input: Write_UserInput,
+    language_type: str = Header(default=None, alias="X-Language-Type"),
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     Write service endpoint - processes write operations synchronously
@@ -136,11 +138,11 @@ async def write_server(
     """
     # 使用集中化的语言校验
     language = get_language_from_header(language_type)
 
     config_id = user_input.config_id
     workspace_id = current_user.current_workspace_id
     api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
 
     # 获取 storage_type,如果为 None 则使用默认值
     storage_type = workspace_service.get_workspace_storage_type(
         db=db,
@@ -149,7 +151,7 @@ async def write_server(
     )
     if storage_type is None: storage_type = 'neo4j'
     user_rag_memory_id = ''
 
     # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id
     if storage_type == 'rag':
         if workspace_id:
@@ -161,13 +163,15 @@ async def write_server(
             if knowledge:
                 user_rag_memory_id = str(knowledge.id)
             else:
-                api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
+                api_logger.warning(
+                    f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储")
                 storage_type = 'neo4j'
         else:
             api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
             storage_type = 'neo4j'
 
-    api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
+    api_logger.info(
+        f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
     try:
         messages_list = memory_agent_service.get_messages_list(user_input)
         result = await memory_agent_service.write_memory(
@@ -175,7 +179,7 @@ async def write_server(
             messages_list,
             config_id,
             db,
-            storage_type,
+            storage_type,
             user_rag_memory_id,
             language
         )
@@ -195,10 +199,10 @@ async def write_server(
 @router.post("/writer_service_async", response_model=ApiResponse)
 @cur_workspace_access_guard()
 async def write_server_async(
-        user_input: Write_UserInput,
-        language_type: str = Header(default=None, alias="X-Language-Type"),
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    user_input: Write_UserInput,
+    language_type: str = Header(default=None, alias="X-Language-Type"),
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     Async write service endpoint - enqueues write processing to Celery
@@ -213,10 +217,11 @@ async def write_server_async(
     """
     # 使用集中化的语言校验
     language = get_language_from_header(language_type)
 
     config_id = user_input.config_id
     workspace_id = current_user.current_workspace_id
-    api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
+    api_logger.info(
+        f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
 
     # 获取 storage_type,如果为 None 则使用默认值
     storage_type = workspace_service.get_workspace_storage_type(
@@ -244,7 +249,7 @@ async def write_server_async(
             args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
         )
         api_logger.info(f"Write task queued: {task.id}")
 
         return success(data={"task_id": task.id}, msg="写入任务已提交")
     except Exception as e:
         api_logger.error(f"Async write operation failed: {str(e)}")
@@ -254,9 +259,9 @@ async def write_server_async(
 @router.post("/read_service", response_model=ApiResponse)
 @cur_workspace_access_guard()
 async def read_server(
-        user_input: UserInput,
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    user_input: UserInput,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     Read service endpoint - processes read operations synchronously
@@ -291,8 +296,9 @@ async def read_server(
             )
             if knowledge:
                 user_rag_memory_id = str(knowledge.id)
 
-    api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
+    api_logger.info(
+        f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
     try:
         result = await memory_agent_service.read_memory(
             user_input.end_user_id,
@@ -306,7 +312,8 @@ async def read_server(
         )
         if str(user_input.search_switch) == "2":
             retrieve_info = result['answer']
-            history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
+            history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
+                                                              user_input.end_user_id)
             query = user_input.message
 
             # 调用 memory_agent_service 的方法生成最终答案
@@ -319,7 +326,7 @@ async def read_server(
                 db=db
             )
             if "信息不足,无法回答" in result['answer']:
-                result['answer']=retrieve_info
+                result['answer'] = retrieve_info
             return success(data=result, msg="回复对话消息成功")
     except BaseException as e:
         # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -335,9 +342,10 @@ async def read_server(
 @router.post("/file", response_model=ApiResponse)
 async def file_update(
     files: List[UploadFile] = File(..., description="要上传的文件"),
-    model_id:str = Form(..., description="模型ID"),
+    model_id: str = Form(..., description="模型ID"),
     metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
-    current_user: User = Depends(get_current_user)
+    current_user: User = Depends(get_current_user),
+    db: Session = Depends(get_db),
 ):
     """
     文件上传接口 - 支持图片识别
@@ -350,9 +358,6 @@ async def file_update(
     Returns:
         文件处理结果
     """
-
-    db_gen = get_db()  # get_db 通常是一个生成器
-    db = next(db_gen)
     api_logger.info(f"File upload requested, file count: {len(files)}")
     config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
     apiConfig: ModelApiKey = config.api_keys[0]
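Worth calling out: this hunk (and the stats endpoint further down) swaps manual generator advancement for dependency injection. next(db_gen) acquires a session but never runs the generator's cleanup, so the session can leak; with Depends(get_db), FastAPI finalizes it after the response. A generic sketch of the two patterns (get_db imported from app.db, as elsewhere in this diff; the route shown is illustrative):

    from fastapi import Depends, FastAPI
    from sqlalchemy.orm import Session

    from app.db import get_db

    app = FastAPI()

    @app.post("/file")
    async def file_update(db: Session = Depends(get_db)):
        # FastAPI drives get_db() as a generator dependency: the session
        # lives for the request and its teardown (close/rollback) runs
        # after the response is sent.
        ...

    # The removed pattern acquires the session but never finalizes it:
    #     db_gen = get_db()
    #     db = next(db_gen)   # session created...
    #     ...                 # ...but db_gen's cleanup never runs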
@@ -361,7 +366,7 @@ async def file_update(
         for file in files:
             api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
             content = await file.read()
 
             if file.content_type and file.content_type.startswith("image/"):
                 vision_model = QWenCV(
                     key=apiConfig.api_key,
@@ -375,12 +380,12 @@ async def file_update(
             else:
                 api_logger.warning(f"Unsupported file type: {file.content_type}")
                 file_content.append(f"[不支持的文件类型: {file.content_type}]")
 
         result_text = ';'.join(file_content)
         api_logger.info(f"File processing completed, result length: {len(result_text)}")
 
         return success(data=result_text, msg="转换文本成功")
 
     except Exception as e:
         api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
         return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -430,8 +435,8 @@ async def read_server_async(
 
 @router.get("/read_result/", response_model=ApiResponse)
 async def get_read_task_result(
-        task_id: str,
-        current_user: User = Depends(get_current_user)
+    task_id: str,
+    current_user: User = Depends(get_current_user)
 ):
     """
     Get the status and result of an async read task
@@ -452,7 +457,7 @@ async def get_read_task_result(
     try:
         result = task_service.get_task_memory_read_result(task_id)
         status = result.get("status")
 
         if status == "SUCCESS":
             # 任务成功完成
             task_result = result.get("result", {})
@@ -470,7 +475,7 @@ async def get_read_task_result(
             else:
                 # 旧格式:直接返回结果
                 return success(data=task_result, msg="查询任务已完成")
 
         elif status == "FAILURE":
             # 任务失败
             error_info = result.get("result", "Unknown error")
@@ -479,7 +484,7 @@ async def get_read_task_result(
             else:
                 error_msg = str(error_info)
             return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
 
         elif status in ["PENDING", "STARTED"]:
             # 任务进行中
             return success(
@@ -499,7 +504,7 @@ async def get_read_task_result(
                 },
                 msg=f"任务状态: {status}"
             )
 
     except Exception as e:
         api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
         return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -507,8 +512,8 @@ async def get_read_task_result(
 
 @router.get("/write_result/", response_model=ApiResponse)
 async def get_write_task_result(
-        task_id: str,
-        current_user: User = Depends(get_current_user)
+    task_id: str,
+    current_user: User = Depends(get_current_user)
 ):
     """
     Get the status and result of an async write task
@@ -529,7 +534,7 @@ async def get_write_task_result(
     try:
         result = task_service.get_task_memory_write_result(task_id)
         status = result.get("status")
 
         if status == "SUCCESS":
             # 任务成功完成
             task_result = result.get("result", {})
@@ -547,7 +552,7 @@ async def get_write_task_result(
             else:
                 # 旧格式:直接返回结果
                 return success(data=task_result, msg="写入任务已完成")
 
         elif status == "FAILURE":
             # 任务失败
             error_info = result.get("result", "Unknown error")
@@ -556,7 +561,7 @@ async def get_write_task_result(
             else:
                 error_msg = str(error_info)
             return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
 
         elif status in ["PENDING", "STARTED"]:
             # 任务进行中
             return success(
@@ -576,7 +581,7 @@ async def get_write_task_result(
                 },
                 msg=f"任务状态: {status}"
             )
 
     except Exception as e:
         api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
         return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -584,9 +589,9 @@ async def get_write_task_result(
 
 @router.post("/status_type", response_model=ApiResponse)
 async def status_type(
-        user_input: Write_UserInput,
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    user_input: Write_UserInput,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     Determine the type of user message (read or write)
@@ -629,9 +634,10 @@ async def status_type(
 
 @router.get("/stats/types", response_model=ApiResponse)
 async def get_knowledge_type_stats_api(
-        end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
-        only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
-        current_user: User = Depends(get_current_user)
+    end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
+    only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
+    current_user: User = Depends(get_current_user),
+    db: Session = Depends(get_db),
 ):
     """
     统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
@@ -640,14 +646,9 @@ async def get_knowledge_type_stats_api(
     - 知识库类型根据当前用户的 current_workspace_id 过滤
     - 如果用户没有当前工作空间,对应的统计返回 0
     """
-    api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
+    api_logger.info(
+        f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
     try:
-        from app.db import get_db
-
-        # 获取数据库会话
-        db_gen = get_db()
-        db = next(db_gen)
-
         # 调用service层函数
         result = await memory_agent_service.get_knowledge_type_stats(
             end_user_id=end_user_id,
@@ -655,7 +656,7 @@ async def get_knowledge_type_stats_api(
             current_workspace_id=current_user.current_workspace_id,
             db=db
         )
 
         return success(data=result, msg="获取知识库类型统计成功")
     except Exception as e:
         api_logger.error(f"Knowledge type stats failed: {str(e)}")
@@ -664,11 +665,11 @@ async def get_knowledge_type_stats_api(
 
 @router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
 async def get_interest_distribution_by_user_api(
-        end_user_id: str = Query(..., description="用户ID(必填)"),
-        limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
-        language_type: str = Header(default=None, alias="X-Language-Type"),
-        current_user: User = Depends(get_current_user),
-        db: Session = Depends(get_db),
+    end_user_id: str = Query(..., description="用户ID(必填)"),
+    limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
+    language_type: str = Header(default=None, alias="X-Language-Type"),
+    current_user: User = Depends(get_current_user),
+    db: Session = Depends(get_db),
 ):
     """
     获取指定用户的兴趣分布标签
@@ -716,9 +717,9 @@ async def get_interest_distribution_by_user_api(
 
 @router.get("/analytics/user_profile", response_model=ApiResponse)
 async def get_user_profile_api(
-        end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     获取用户详情,包含:
@@ -756,17 +757,17 @@ async def get_user_profile_api(
 # ):
 #     """
 #     Get parsed API documentation (Public endpoint - no authentication required)
 
 #     Args:
 #         file_path: Optional path to API docs file. If None, uses default path.
 
 #     Returns:
 #         Parsed API documentation including title, meta info, and sections
 #     """
 #     api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
 #     try:
 #         result = await memory_agent_service.get_api_docs(file_path)
 
 #         if result.get("success"):
 #             return success(msg=result["msg"], data=result["data"])
 #         else:
@@ -782,9 +783,9 @@ async def get_user_profile_api(
 
 @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
 async def get_end_user_connected_config(
-        end_user_id: str,
-        db: Session = Depends(get_db),
-        current_user: User = Depends(get_current_user)
+    end_user_id: str,
+    db: Session = Depends(get_db),
+    current_user: User = Depends(get_current_user)
 ):
     """
     获取终端用户关联的记忆配置
@@ -803,9 +804,9 @@ async def get_end_user_connected_config(
     from app.services.memory_agent_service import (
         get_end_user_connected_config as get_config,
     )
 
     api_logger.info(f"Getting connected config for end_user: {end_user_id}")
 
     try:
         result = get_config(end_user_id, db)
         return success(data=result, msg="获取终端用户关联配置成功")
@@ -814,4 +815,4 @@ async def get_end_user_connected_config(
         return fail(BizCode.NOT_FOUND, str(e))
     except Exception as e:
         api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
-        return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
+        return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
@@ -606,8 +606,8 @@ async def dashboard_data(
 
     # 获取RAG相关数据
     try:
-        # total_memory: 使用 total_chunk(总chunk数)
-        total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
+        # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
+        total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
         rag_data["total_memory"] = total_chunk
 
         # total_app: 统计当前空间下的所有app数量
@@ -2,7 +2,7 @@ from typing import Optional
 from uuid import UUID
 
 from fastapi import APIRouter, Depends, Query
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 from sqlalchemy.orm import Session
 
 from app.core.error_codes import BizCode
@@ -85,6 +85,7 @@ def create_config(
     payload: ConfigParamsCreate,
     current_user: User = Depends(get_current_user),
     db: Session = Depends(get_db),
+    x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
 ) -> dict:
     workspace_id = current_user.current_workspace_id
     # 检查用户是否已选择工作空间
@@ -99,7 +100,29 @@ def create_config(
         svc = DataConfigService(db)
         result = svc.create(payload)
         return success(data=result, msg="创建成功")
+    except ValueError as e:
+        err_str = str(e)
+        if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
+            config_name = err_str.split(":", 1)[1]
+            api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
+            lang = get_language_from_header(x_language_type)
+            if lang == "en":
+                msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
+            else:
+                msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
+            return JSONResponse(status_code=400, content=msg)
+        api_logger.error(f"Create config failed: {err_str}")
+        return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
     except Exception as e:
+        from sqlalchemy.exc import IntegrityError
+        if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
+            api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
+            lang = get_language_from_header(x_language_type)
+            if lang == "en":
+                msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
+            else:
+                msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
+            return JSONResponse(status_code=400, content=msg)
         api_logger.error(f"Create config failed: {str(e)}")
         return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
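The new ValueError branch implies a service-side convention: encode the offending name into the exception message behind a sentinel prefix. A sketch of what DataConfigService.create presumably does (the service is not in this diff; the repository call is hypothetical):

    def create(self, payload):
        # Hypothetical: the controller above parses "DUPLICATE_CONFIG_NAME:<name>".
        existing = self.repo.get_by_name(self.workspace_id, payload.config_name)
        if existing is not None:
            raise ValueError(f"DUPLICATE_CONFIG_NAME:{payload.config_name}")
        ...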
@@ -469,7 +469,9 @@ async def create_model_api_key_by_provider(
         config=api_key_data.config,
         is_active=api_key_data.is_active,
         priority=api_key_data.priority,
-        model_config_ids=model_config_ids
+        model_config_ids=model_config_ids,
+        capability=api_key_data.capability,
+        is_omni=api_key_data.is_omni
     )
     created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
@@ -25,7 +25,7 @@ from typing import Dict, Optional, List
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, JSONResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -124,15 +124,23 @@ def _get_ontology_service(
|
||||
)
|
||||
|
||||
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
if not api_keys:
|
||||
# from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
|
||||
if not api_key_config:
|
||||
logger.error(f"Model {llm_id} has no active API key")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="指定的LLM模型没有可用的API密钥"
|
||||
)
|
||||
api_key_config = api_keys[0]
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
# if not api_keys:
|
||||
# logger.error(f"Model {llm_id} has no active API key")
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail="指定的LLM模型没有可用的API密钥"
|
||||
# )
|
||||
# api_key_config = api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
@@ -154,6 +162,7 @@ def _get_ontology_service(
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
max_retries=3,
timeout=60.0
)
@@ -280,7 +289,8 @@ async def extract_ontology(
async def create_scene(
request: SceneCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体场景

@@ -351,8 +361,18 @@ async def create_scene(
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

except RuntimeError as e:
api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str:
api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}")
from app.core.language_utils import get_language_from_header
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str)

except Exception as e:
api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True)
@@ -514,10 +534,9 @@ async def delete_scene(
f"尝试删除系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可删除",
"该场景为系统预设场景,不允许删除"
raise HTTPException(
status_code=400,
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
)

# 创建OntologyService实例
@@ -543,6 +562,9 @@ async def delete_scene(

return success(data={"deleted": success_flag}, msg="场景删除成功")

except HTTPException:
raise

except ValueError as e:
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
@@ -650,7 +672,8 @@ async def get_scenes(
async def create_class(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type")
):
"""创建本体类型

@@ -665,7 +688,7 @@ async def create_class(
ApiResponse: 包含创建的类型信息
"""
from app.controllers.ontology_secondary_routes import create_class_handler
return await create_class_handler(request, db, current_user)
return await create_class_handler(request, db, current_user, x_language_type)
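
create_class now forwards the raw X-Language-Type header into the handler so duplicate-name errors can be localized. get_language_from_header itself is not shown in this diff; a plausible minimal version, assuming it normalizes the header to "en" or "zh" with Chinese as the default, might be:

from typing import Optional

def get_language_from_header(x_language_type: Optional[str]) -> str:
    # Hypothetical sketch of the helper used above.
    if x_language_type and x_language_type.strip().lower().startswith("en"):
        return "en"
    return "zh"
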
@router.put("/class/{class_id}", response_model=ApiResponse)

@@ -7,7 +7,7 @@
from uuid import UUID
from typing import Optional

from fastapi import Depends
from fastapi import Depends, Header
from sqlalchemy.orm import Session

from app.core.error_codes import BizCode
@@ -58,7 +58,7 @@ async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
@@ -71,14 +71,14 @@ async def scenes_handler(
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码(可选,从1开始,仅在全量查询时有效)
page_size: 每页数量(可选,仅在全量查询时有效)
pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
)

try:
@@ -105,13 +105,13 @@ async def scenes_handler(
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")

if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")

# 如果只提供了page或page_size中的一个,返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个,返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")

# 模糊搜索场景(支持分页)
@@ -119,17 +119,15 @@ async def scenes_handler(
total = len(scenes)

# 如果提供了分页参数,进行分页处理
if page is not None and page_size is not None:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
if page is not None and pagesize is not None:
start_idx = (page - 1) * pagesize
end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx]

# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0

items.append(SceneResponse(
@@ -141,17 +139,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))

# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total

if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
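
The pagination logic above slices the full result list and derives hasnext from page * pagesize against the total. The same arithmetic as a standalone helper (1-based pages, as in the handler):

def paginate(items: list, page: int, pagesize: int) -> tuple:
    start = (page - 1) * pagesize
    hasnext = page * pagesize < len(items)   # items remain past this page
    return items[start:start + pagesize], hasnext

# paginate(list(range(10)), page=2, pagesize=4) -> ([4, 5, 6, 7], True)
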
@@ -165,28 +162,25 @@ async def scenes_handler(
)
else:
# 获取所有场景(支持分页)
# 验证分页参数
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")

if page_size is not None and page_size < 1:
api_logger.warning(f"Invalid page_size: {page_size}")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")

# 如果只提供了page或page_size中的一个,返回错误
if (page is not None and page_size is None) or (page is None and page_size is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
# 如果只提供了page或pagesize中的一个,返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")

scenes, total = service.list_scenes(ws_uuid, page, page_size)
scenes, total = service.list_scenes(ws_uuid, page, pagesize)

# 构建响应
items = []
for scene in scenes:
# 获取前3个class_name作为entity_type
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
# 动态计算 type_num
type_num = len(scene.classes) if scene.classes else 0

items.append(SceneResponse(
@@ -198,17 +192,16 @@ async def scenes_handler(
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num
classes_count=type_num,
is_system_default=scene.is_system_default
))

# 构建响应(包含分页信息)
if page is not None and page_size is not None:
# 计算是否有下一页
hasnext = (page * page_size) < total

if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=page_size,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
@@ -238,7 +231,8 @@ async def scenes_handler(
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""

@@ -271,8 +265,11 @@ async def create_class_handler(
]

if count == 1:
# 单个创建
# 单个创建 - 先检查重名
class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
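
The pre-check raises a sentinel-prefixed ValueError so the except branch further down can recover the offending class name without defining a custom exception type. The round trip, in isolation:

SENTINEL = "DUPLICATE_CLASS_NAME:"

def parse_duplicate(err: ValueError):
    # Returns the duplicated class name, or None for unrelated ValueErrors.
    text = str(err)
    return text.split(":", 1)[1] if text.startswith(SENTINEL) else None
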
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
@@ -330,12 +327,36 @@ async def create_class_handler(
return success(data=response.model_dump(mode='json'), msg="批量创建完成")

except ValueError as e:
api_logger.warning(f"Validation error in class creation: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

err_str = str(e)
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)

except RuntimeError as e:
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
err_str = str(e)
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)

except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
@@ -615,6 +636,7 @@ async def classes_handler(
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items
)

@@ -249,6 +249,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

@@ -39,7 +39,7 @@ async def write_memory_api_service(

Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")

memory_api_service = MemoryAPIService(db)

@@ -11,35 +11,37 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence

from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool

logger = get_business_logger()

class LangChainAgent:

def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
):
"""初始化 LangChain Agent

@@ -60,12 +62,13 @@ class LangChainAgent:
self.provider = provider
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls

# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None

# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
@@ -73,9 +76,9 @@ class LangChainAgent:
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
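
With the default, the iteration budget grows with the toolset: 5 base iterations plus 2 per tool, so the agent gets extra room whenever more tools are attached.

tools = ["time_retrieval", "hybrid_retrieval", "memory"]  # any sequence of tools
max_iterations = 5 + len(tools) * 2                       # -> 11 for three tools
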
self.system_prompt = system_prompt or "你是一个专业的AI助手"

logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
@@ -89,6 +92,7 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -143,21 +147,22 @@ class LangChainAgent:
"""
from langchain_core.tools import StructuredTool
from functools import wraps

wrapped_tools = []

for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None

if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue

# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""

@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
@@ -168,13 +173,13 @@ class LangChainAgent:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name

current_count = self.tool_call_counter[tool_name]

logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)

# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
@@ -185,12 +190,12 @@ class LangChainAgent:
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)

# 调用原始工具函数
return original_func(*args, **kwargs)

return wrapped_func

# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
@@ -198,17 +203,17 @@ class LangChainAgent:
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)

wrapped_tools.append(wrapped_tool)

return wrapped_tools
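
The wrapper factory above closes over each tool's name and function to dodge Python's late-binding-in-loops pitfall, then counts consecutive calls per tool. A standalone, simplified sketch of the same idea (counting only, without the reset-on-switch logic the class keeps in tool_call_counter):

from functools import wraps

def make_counted(tool_name: str, func, counter: dict, limit: int = 3):
    """One closure per tool, created eagerly inside the loop."""
    @wraps(func)
    def wrapped(*args, **kwargs):
        counter[tool_name] = counter.get(tool_name, 0) + 1
        if counter[tool_name] > limit:
            return f"tool '{tool_name}' hit the {limit}-call limit; try another approach"
        return func(*args, **kwargs)
    return wrapped

counter: dict = {}
echo = make_counted("echo", lambda s: s, counter)  # echo("hi") -> "hi"
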
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> List[BaseMessage]:
"""准备消息列表

@@ -248,7 +253,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))

return messages

def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -261,23 +266,26 @@ class LangChainAgent:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
if self.provider.lower() in ["bedrock", "anthropic"]:
# Anthropic/Bedrock: {"type": "text", "text": "..."}
content_parts = [{"type": "text", "text": text}]
else:
# 通义千问等: {"text": "..."}
content_parts = [{"text": text}]

# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]

# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)

logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)

return content_parts
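
After this change the text part always uses the {"type": "text", "text": ...} shape, and provider-specific differences are pushed into MultimodalService, which pre-formats the file parts. Assembled by hand, a text-plus-image content list could look like the following (the image part uses the OpenAI-style shape as an illustration; other providers expect different keys):

text_part = {"type": "text", "text": "描述这张图片"}
image_part = {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}
content_parts = [text_part, image_part]
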
async def chat(
@@ -302,7 +310,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat= message
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -322,8 +330,8 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -367,14 +375,14 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""

logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")

# 处理多模态响应:content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
@@ -407,12 +415,13 @@ class LangChainAgent:
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break

logger.info(f"最终提取的内容长度: {len(content)}")

elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -439,16 +448,16 @@ class LangChainAgent:
raise

async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
"""执行流式对话

@@ -482,7 +491,6 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")

# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
@@ -500,13 +508,13 @@ class LangChainAgent:
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
chunk_count += 1
kind = event.get("event")

# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
@@ -540,7 +548,7 @@ class LangChainAgent:
full_content += item
yield item
yielded_content = True

elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
@@ -577,13 +585,13 @@ class LangChainAgent:
full_content += chunk
yield chunk
yielded_content = True

# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")

logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
@@ -595,7 +603,8 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -609,5 +618,3 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

@@ -190,8 +190,10 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB

# Celery configuration (internal)
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "1"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "2"))
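
Per the NOTE above, the Redis DB numbers for Celery moved out of the CELERY_ namespace because environment variables with that prefix risk being picked up as Celery configuration (the referenced bug report has the details). The application then assembles the broker/backend URLs itself; a sketch of how these settings would typically be combined, assuming the usual REDIS_HOST/REDIS_PORT variables:

import os

redis_host = os.getenv("REDIS_HOST", "127.0.0.1")
redis_port = os.getenv("REDIS_PORT", "6379")
broker_db = os.getenv("REDIS_DB_CELERY_BROKER", "1")
backend_db = os.getenv("REDIS_DB_CELERY_BACKEND", "2")

broker_url = f"redis://{redis_host}:{redis_port}/{broker_db}"        # redis://127.0.0.1:6379/1
result_backend = f"redis://{redis_host}:{redis_port}/{backend_db}"   # redis://127.0.0.1:6379/2
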
# SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
@@ -219,8 +221,12 @@ class Settings:
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))

IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行(分钟,0-59)
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)

MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")

@@ -1,10 +1,10 @@
import os
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.db import get_db

from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context

template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)

@@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState:

try:
# 使用优化的LLM服务
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
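
Both LLM call sites now borrow a short-lived session from get_db_context instead of sharing the module-level next(get_db()) session, so each call gets a clean session that is closed afterwards. A minimal context manager with that behavior, assuming a SQLAlchemy SessionLocal factory (the project's actual get_db_context may add commit/rollback handling):

from contextlib import contextmanager

@contextmanager
def get_db_context():
    db = SessionLocal()  # assumed session factory
    try:
        yield db
    finally:
        db.close()
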
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
@@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState:

try:
# 使用优化的LLM服务
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)

logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")

@@ -6,31 +6,26 @@ import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI

from app.core.logging_config import get_agent_logger
from app.db import get_db, get_db_context

from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService

from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
COUNTState,
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)

from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService

logger = get_agent_logger(__name__)
db = next(get_db())

async def rag_config(state):
@@ -50,10 +45,12 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state,question):

async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -61,13 +58,13 @@ async def rag_knowledge(state,question):
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception :
retrieval_knowledge=[]
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results
return retrieval_knowledge, clean_content, cleaned_query, raw_results

async def llm_infomation(state: ReadState) -> ReadState:
@@ -113,7 +110,7 @@ async def clean_databases(data) -> str:

# 收集所有内容
content_list = []

# 处理重排序结果
reranked = results.get('reranked_results', {})
if reranked:
@@ -141,7 +138,6 @@ async def clean_databases(data) -> str:
elif isinstance(item, str):
text_parts.append(item)

return '\n'.join(text_parts).strip()

except Exception as e:
@@ -150,23 +146,23 @@ async def clean_databases(data) -> str:

async def retrieve_nodes(state: ReadState) -> ReadState:

'''

模型信息
'''

problem_extension=state.get('problem_extension', '')['context']
storage_type=state.get('storage_type', '')
user_rag_memory_id=state.get('user_rag_memory_id', '')
end_user_id=state.get('end_user_id', '')
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original=state.get('data', '')
problem_list=[]
for key,values in problem_extension.items():
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")

# 创建异步任务处理单个问题
async def process_question_nodes(idx, question):
try:
@@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:

send_verify = []
for i, j in zip(keys, val, strict=False):
if j!=['']:
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
@@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
}

logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve':dup_databases}

return {'retrieve': dup_databases}

async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
import time
start=time.time()
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
@@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState:
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)

llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
@@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState:
)

time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool,hybrid_retrieval],
tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)

@@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState:
async with SEMAPHORE: # 限制并发
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
@@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

@@ -1,5 +1,3 @@

import os
import time

@@ -18,22 +16,24 @@ from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval

from app.db import get_db
from app.db import get_db_context

template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
db_session = next(get_db())

class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""

def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)

# 创建全局服务实例
summary_service = SummaryNodeService()

async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
@@ -51,10 +51,12 @@ async def rag_config(state):
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state,question):

async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id=state.get("user_rag_memory_id",'')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
@@ -62,25 +64,28 @@ async def rag_knowledge(state,question):
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception :
retrieval_knowledge=[]
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge,clean_content,cleaned_query,raw_results
return retrieval_knowledge, clean_content, cleaned_query, raw_results

async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history

async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:

async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数,包含更好的错误处理和数据验证
"""
data = state.get("data", '')

# 构建系统提示词
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
@@ -99,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
)
try:
# 使用优化的LLM服务进行结构化输出
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
if structured is None:
logger.warning("LLM返回None,使用默认回答")
return "信息不足,无法回答"

# 根据操作类型提取答案
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
@@ -121,16 +127,16 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"

# 验证答案不为空
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"

return aimessages

except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)

# 尝试非结构化输出作为fallback
try:
logger.info("尝试非结构化输出作为fallback")
@@ -140,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)

if response and response.strip():
# 简单清理响应
cleaned_response = response.strip()
@@ -148,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])

return cleaned_response
else:
return "信息不足,无法回答"

except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"

async def summary_redis_save(state: ReadState,aimessages) -> ReadState:

async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
@@ -169,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
storage_type=state.get("storage_type",'')
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')

async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
@@ -189,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve={
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title":"快速检索",
"title": "快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
@@ -204,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
}
}

return input_summary,retrieve
return input_summary, retrieve

async def Input_Summary(state: ReadState) -> ReadState:
start=time.time()
storage_type=state.get("storage_type",'')
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id=state.get("user_rag_memory_id",'')
data=state.get("data", '')
end_user_id=state.get("end_user_id", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history( state)
history = await summary_history(state)
search_params = {
"end_user_id": end_user_id,
"question": data,
@@ -223,12 +233,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
}

try:
if storage_type!="rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
@@ -237,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState:
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error( f"Input_Summary failed: {e}", exc_info=True )
summary= {
logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary = {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
@@ -251,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState:
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary":summary}
return {"summary": summary}

async def Retrieve_Summary(state: ReadState)-> ReadState:
retrieve=state.get("retrieve", '')
history = await summary_history( state)

async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
with open("检索.json","w",encoding='utf-8') as f:
with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve=retrieve.get("Expansion_issue", [])
start=time.time()
retrieve_info_str=[]
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data=='':
retrieve_info_str=''
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key=='Answer_Small':
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str=list(set(retrieve_info_str))
retrieve_info_str='\n'.join(retrieve_info_str)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)

aimessages=await summary_llm(state,history,retrieve_info_str,
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -286,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)

# 修复协程调用 - 先await,然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
return {"summary": summary}

async def Summary(state: ReadState)-> ReadState:
start=time.time()
async def Summary(state: ReadState) -> ReadState:
start = time.time()
query = state.get("data", '')
verify=state.get("verify", '')
verify_expansion_issue=verify.get("verified_data", '')
retrieve_info_str=''
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key=='answer_small':
if key == 'answer_small':
for i in value:
retrieve_info_str+=i+'\n'
history=await summary_history(state)
retrieve_info_str += i + '\n'
history = await summary_history(state)

data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages=await summary_llm(state,history,data,
'summary_prompt.jinja2','summary',SummaryResponse,0)
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)

if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
@@ -327,10 +339,12 @@ async def Summary(state: ReadState)-> ReadState:
# 修复协程调用 - 先await,然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary":summary}
async def Summary_fails(state: ReadState)-> ReadState:
storage_type=state.get("storage_type", '')
user_rag_memory_id=state.get("user_rag_memory_id", '')
return {"summary": summary}

async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
@@ -346,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState:
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result= {
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary":result}
return {"summary": result}

@@ -1,8 +1,9 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.db import get_db

from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
@@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.db import get_db_context

template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db())
logger = get_agent_logger(__name__)

class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""

def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)

# 创建全局服务实例
verification_service = VerificationNodeService()

async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')

# 将 VerificationItem 对象转换为字典列表
verified_data = []
if messages_deal.expansion_issue:
@@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)

Verify_result = {
"status": messages_deal.split_result,
"verified_data": verified_data,
@@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
}
}
return Verify_result

async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)

logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")

history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")

retrieve = state.get("retrieve", {})
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")

logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")

retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")

messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}

logger.info("Verify: 开始渲染模板")

# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema()

system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
@@ -94,29 +100,30 @@ async def Verify(state: ReadState):
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")

# 使用优化的LLM服务,添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
import asyncio
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)

with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成,result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
@@ -127,11 +134,11 @@ async def Verify(state: ReadState):
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
@@ -152,4 +159,4 @@ async def Verify(state: ReadState):
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from langgraph.graph import StateGraph

from app.db import get_db
from app.services.memory_config_service import MemoryConfigService

@@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
)


@asynccontextmanager
async def make_read_graph():
    """Create and return the LangGraph workflow"""
@@ -49,7 +47,7 @@ async def make_read_graph():
        workflow.add_node("Retrieve_Summary", Retrieve_Summary)
        workflow.add_node("Summary", Summary)
        workflow.add_node("Summary_fails", Summary_fails)

        # Add edges
        workflow.add_edge(START, "content_input")
        workflow.add_conditional_edges("content_input", Split_continue)
@@ -62,20 +60,20 @@ async def make_read_graph():
        workflow.add_edge("Summary_fails", END)
        workflow.add_edge("Summary", END)

        '''-----'''
        # workflow.add_edge("Retrieve", END)

        # Compile the workflow
        graph = workflow.compile()
        yield graph

    except Exception as e:
        print(f"Failed to create workflow: {e}")
        raise
    finally:
        print("Workflow creation finished")


async def main():
    """Entry point - run the workflow"""
    message = "昨天有什么好看的电影"
@@ -92,17 +90,19 @@ async def main():
        service_name="MemoryAgentService"
    )
    import time
    start=time.time()
    start = time.time()
    try:
        async with make_read_graph() as graph:
            config = {"configurable": {"thread_id": end_user_id}}
            # Initial state - includes all required fields
            initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
                ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
            initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
                             "end_user_id": end_user_id
                             , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
                             "memory_config": memory_config}
            # Collect node update info
            _intermediate_outputs = []
            summary = ''

            async for update_event in graph.astream(
                initial_state,
                stream_mode="updates",
@@ -110,7 +110,7 @@ async def main():
            ):
                for node_name, node_data in update_event.items():
                    print(f"Processing node: {node_name}")

                    # Handle the return structures of the different Summary nodes
                    if 'Summary' in node_name:
                        if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
@@ -125,23 +125,22 @@ async def main():
                    spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
                    if spit_data and spit_data != [] and spit_data != {}:
                        _intermediate_outputs.append(spit_data)

                    # Problem_Extension node (a generic collector is sketched after this file's diff)
                    problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
                    if problem_extension and problem_extension != [] and problem_extension != {}:
                        _intermediate_outputs.append(problem_extension)

                    # Retrieve node
                    retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
                    if retrieve_node and retrieve_node != [] and retrieve_node != {}:
                        _intermediate_outputs.extend(retrieve_node)

                    # Verify node
                    verify_n = node_data.get('verify', {}).get('_intermediate', None)
                    if verify_n and verify_n != [] and verify_n != {}:
                        _intermediate_outputs.append(verify_n)

                    # Summary node
                    summary_n = node_data.get('summary', {}).get('_intermediate', None)
                    if summary_n and summary_n != [] and summary_n != {}:
@@ -161,17 +160,20 @@ async def main():
        #
        print(f"=== Final summary ===")
        print(summary)

    except Exception as e:
        import traceback
        traceback.print_exc()
    finally:
        db_session.close()

        end=time.time()
        print(100*'y')
        print(f"Total time: {end-start}s")
        print(100*'y')
        end = time.time()
        print(100 * 'y')
        print(f"Total time: {end - start}s")
        print(100 * 'y')


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
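The per-node blocks in the update loop above all follow the same "fetch an intermediate payload, skip empty values, accumulate" shape. A hedged sketch of how they could collapse into one helper; the key names mirror the loop above, but the exact node payload layout is an assumption:

def collect_intermediates(node_data: dict, sink: list) -> None:
    # (node key in the update event, field holding its intermediate output)
    for key, field in (("spit_data", "_intermediate"),
                       ("problem_extension", "_intermediate"),
                       ("retrieve", "_intermediate_outputs"),
                       ("verify", "_intermediate"),
                       ("summary", "_intermediate")):
        payload = node_data.get(key, {}).get(field)
        if not payload:  # skips None, [], and {} in one check
            continue
        if isinstance(payload, list):
            sink.extend(payload)
        else:
            sink.append(payload)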
@@ -1,56 +0,0 @@

import asyncio
from typing import Dict, Optional
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
from app.db import get_db
from app.core.logging_config import get_agent_logger

logger = get_agent_logger(__name__)


class LLMClientPool:
    """LLM client connection pool"""

    def __init__(self, max_size: int = 5):
        self.max_size = max_size
        self.pools: Dict[str, asyncio.Queue] = {}
        self.active_clients: Dict[str, int] = {}

    async def get_client(self, llm_model_id: str):
        """Acquire an LLM client"""
        if llm_model_id not in self.pools:
            self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
            self.active_clients[llm_model_id] = 0

        pool = self.pools[llm_model_id]

        try:
            # Try to take a client from the pool
            client = pool.get_nowait()
            logger.debug(f"Got LLM client from pool: {llm_model_id}")
            return client
        except asyncio.QueueEmpty:
            # Pool is empty, create a new client
            if self.active_clients[llm_model_id] < self.max_size:
                db_session = next(get_db())
                client = get_llm_client_fast(llm_model_id, db_session)
                self.active_clients[llm_model_id] += 1
                logger.debug(f"Created new LLM client: {llm_model_id}")
                return client
            else:
                # Wait for an available client
                logger.debug(f"Waiting for an available LLM client: {llm_model_id}")
                return await pool.get()

    async def return_client(self, llm_model_id: str, client):
        """Return an LLM client to the pool"""
        if llm_model_id in self.pools:
            try:
                self.pools[llm_model_id].put_nowait(client)
                logger.debug(f"Returned LLM client to pool: {llm_model_id}")
            except asyncio.QueueFull:
                # Pool is full, discard the client
                self.active_clients[llm_model_id] -= 1
                logger.debug(f"Pool is full, discarding LLM client: {llm_model_id}")


# Global client pool
llm_client_pool = LLMClientPool()
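One weakness of the deleted pool API above is that callers must remember to call return_client on every path, or clients leak. A hedged sketch of an acquire/release wrapper that guarantees the return, assuming only the get_client / return_client methods shown above:

from contextlib import asynccontextmanager

@asynccontextmanager
async def leased_client(pool, llm_model_id: str):
    # Acquire from the pool and always give the client back, even on error.
    client = await pool.get_client(llm_model_id)
    try:
        yield client
    finally:
        await pool.return_client(llm_model_id, client)

# Usage sketch:
# async with leased_client(llm_client_pool, "some-model-id") as client:
#     ... use client ...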
@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field

T = TypeVar("T")


class RedBearModelConfig(BaseModel):
    """Base model configuration"""
    model_name: str
    provider: str
    api_key: str
    base_url: Optional[str] = None
    is_omni: bool = False  # whether this is an Omni model
    # Request timeout in seconds - defaults to 120s to support complex LLM calls; configurable via the LLM_TIMEOUT env var
    timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
    # Max retries - defaults to 2 to avoid overly long waits; configurable via the LLM_MAX_RETRIES env var
    max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
    concurrency: int = 5  # concurrency limit
    concurrency: int = 5  # concurrency limit
    extra_params: Dict[str, Any] = {}


class RedBearModelFactory:
    """Model factory"""

    @classmethod
    def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
        """Get model parameters for the given provider"""
        provider = config.provider.lower()

        # Log provider info for debugging
        from app.core.logging_config import get_business_logger
        logger = get_business_logger()
        logger.debug(f"Getting model params - Provider: {provider}, Model: {config.model_name}")
        logger.debug(f"Getting model params - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")

        # dashscope omni models use the OpenAI-compatible mode
        if provider == ModelProvider.DASHSCOPE and config.is_omni:
            import httpx
            if not config.base_url:
                config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
            timeout_config = httpx.Timeout(
                timeout=config.timeout,
                connect=60.0,
                read=config.timeout,
                write=60.0,
                pool=10.0,
            )
            return {
                "model": config.model_name,
                "base_url": config.base_url,
                "api_key": config.api_key,
                "timeout": timeout_config,
                "max_retries": config.max_retries,
                **config.extra_params
            }

        if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
            # Use an httpx.Timeout object for fine-grained timeout configuration
@@ -65,7 +89,7 @@ class RedBearModelFactory:
                "timeout": timeout_config,
                "max_retries": config.max_retries,
                **config.extra_params
            }
            }
        elif provider == ModelProvider.DASHSCOPE:
            # DashScope (Tongyi Qianwen) uses its own parameter format
            # Note: DashScopeEmbeddings does not support the timeout and base_url parameters
@@ -82,7 +106,7 @@ class RedBearModelFactory:
            # region comes from base_url or extra_params
            from botocore.config import Config as BotoConfig
            from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id

            max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
            max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
            # Configure with increased connection pool
@@ -90,16 +114,16 @@ class RedBearModelFactory:
                max_pool_connections=max_pool_connections,
                retries={'max_attempts': max_retries, 'mode': 'adaptive'}
            )

            # Normalize the model ID (automatically expands short names to full Bedrock model IDs)
            model_id = normalize_bedrock_model_id(config.model_name)

            params = {
                "model_id": model_id,
                "config": boto_config,
                **config.extra_params
            }

            # Parse the API key (format: access_key_id:secret_access_key)
            if config.api_key and ":" in config.api_key:
                access_key_id, secret_access_key = config.api_key.split(":", 1)
@@ -107,45 +131,52 @@ class RedBearModelFactory:
                params["aws_secret_access_key"] = secret_access_key
            elif config.api_key:
                params["aws_access_key_id"] = config.api_key

            # Set the region
            if config.base_url:
                params["region_name"] = config.base_url
            elif "region_name" not in params:
                params["region_name"] = "us-east-1"  # default region

            return params
        else:
            raise BusinessException(f"Unsupported provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

    @classmethod
    def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
        """Get model parameters for the given provider"""
        provider = config.provider.lower()
        if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
            return {
            return {
                "model": config.model_name,
                # "base_url": config.base_url,
                "jina_api_key": config.api_key,
                **config.extra_params
            }
            }
        else:
            raise BusinessException(f"Unsupported provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:

def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
    """Get the model class for the given provider"""
    provider = config.provider.lower()
    if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :

    # dashscope omni models use the OpenAI-compatible mode
    if provider == ModelProvider.DASHSCOPE and config.is_omni:
        from langchain_openai import ChatOpenAI
        return ChatOpenAI

    if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
        if type == ModelType.LLM:
            from langchain_openai import OpenAI
            return OpenAI
            return OpenAI
        elif type == ModelType.CHAT:
            from langchain_openai import ChatOpenAI
            return ChatOpenAI
    elif provider == ModelProvider.DASHSCOPE:
        from langchain_community.chat_models import ChatTongyi
        return ChatTongyi
    elif provider == ModelProvider.OLLAMA:
    elif provider == ModelProvider.OLLAMA:
        from langchain_ollama import OllamaLLM
        return OllamaLLM
    elif provider == ModelProvider.BEDROCK:
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
    else:
        raise BusinessException(f"Unsupported model provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)


def get_provider_embedding_class(provider: str) -> type[Embeddings]:
    """Get the model class for the given provider"""
    provider = provider.lower()
    if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
    if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
        from langchain_openai import OpenAIEmbeddings
        return OpenAIEmbeddings
        return OpenAIEmbeddings
    elif provider == ModelProvider.DASHSCOPE:
        from langchain_community.embeddings import DashScopeEmbeddings
        return DashScopeEmbeddings
        return DashScopeEmbeddings
    elif provider == ModelProvider.OLLAMA:
        from langchain_ollama import OllamaEmbeddings
        return OllamaEmbeddings
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
    else:
        raise BusinessException(f"Unsupported model provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)


def get_provider_rerank_class(provider: str):
    """Get the model class for the given provider"""
    provider = provider.lower()
    if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
    provider = provider.lower()
    if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
        from langchain_community.document_compressors import JinaRerank
        return JinaRerank
    # elif provider == ModelProvider.OLLAMA:
        return JinaRerank
    # elif provider == ModelProvider.OLLAMA:
    #     from langchain_ollama import OllamaEmbeddings
    #     return OllamaEmbeddings
    else:
        raise BusinessException(f"Unsupported model provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
        raise BusinessException(f"Unsupported model provider: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
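A hedged usage sketch of the factory above: a DashScope omni config should fall into the new OpenAI-compatible branch and pick up the compatible-mode base_url automatically. Field names follow RedBearModelConfig as shown in the diff; the model name and key are illustrative values, and this assumes the module is importable as-is:

config = RedBearModelConfig(
    model_name="qwen3-omni-flash",
    provider="dashscope",
    api_key="sk-...",
    is_omni=True,
)
params = RedBearModelFactory.get_model_params(config)
# The omni branch injects the compatible-mode endpoint when base_url is unset:
assert params["base_url"] == "https://dashscope.aliyuncs.com/compatible-mode/v1"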
@@ -6,6 +6,8 @@ models:
    description: AI21 Labs large language model, completion mode, 256000-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
    logo: bedrock
@@ -15,6 +17,9 @@ models:
    description: Amazon Nova large language model, supports agent thinking, tool calling, streaming tool calling, and vision, 300000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -28,6 +33,9 @@ models:
    description: Anthropic Claude large language model, supports agent thinking, vision, tool calling, streaming tool calling, and document processing, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -42,6 +50,8 @@ models:
    description: Cohere large language model, supports agent thinking, tool calling, and streaming tool calling, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -54,6 +64,9 @@ models:
    description: DeepSeek large language model, supports agent thinking, vision, tool calling, and streaming tool calling, 32768-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -67,6 +80,8 @@ models:
    description: Meta Llama large language model, supports agent thinking and tool calling, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -78,6 +93,8 @@ models:
    description: Mistral AI large language model, supports agent thinking and tool calling, 32000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -89,6 +106,8 @@ models:
    description: OpenAI large language model, supports agent thinking, tool calling, and streaming tool calling, 32768-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -101,6 +120,8 @@ models:
    description: Qwen large language model, supports agent thinking, tool calling, and streaming tool calling, 32768-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -113,6 +134,8 @@ models:
    description: amazon.rerank-v1:0 rerank model, 5120-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Rerank model
    logo: bedrock
@@ -122,6 +145,8 @@ models:
    description: cohere.rerank-v3-5:0 rerank model, 5120-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Rerank model
    logo: bedrock
@@ -131,6 +156,9 @@ models:
    description: amazon.nova-2-multimodal-embeddings-v1:0 text embedding model, supports vision, 8192-token context window
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Text embedding model
      - vision
@@ -141,6 +169,8 @@ models:
    description: amazon.titan-embed-text-v1 text embedding model, 8192-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: bedrock
@@ -150,6 +180,8 @@ models:
    description: amazon.titan-embed-text-v2:0 text embedding model, 8192-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: bedrock
@@ -159,6 +191,8 @@ models:
    description: Cohere Embed 3 English text embedding model, 512-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: bedrock
@@ -168,6 +202,8 @@ models:
    description: Cohere Embed 3 Multilingual text embedding model, 512-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: bedrock
    logo: bedrock
@@ -6,6 +6,8 @@ models:
    description: DeepSeek-R1-Distill-Qwen-14B large language model, supports agent thinking, 32000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -16,6 +18,8 @@ models:
    description: DeepSeek-R1-Distill-Qwen-32B large language model, supports agent thinking, 32000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -26,6 +30,8 @@ models:
    description: DeepSeek-R1 large language model, supports agent thinking, 131072-token extra-large context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -36,6 +42,8 @@ models:
    description: DeepSeek-V3.1 large language model, supports agent thinking, 131072-token extra-large context window, chat mode, supports rich generation-parameter tuning
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -46,6 +54,8 @@ models:
    description: DeepSeek-V3.2-exp experimental large language model, supports agent thinking, 131072-token extra-large context window, chat mode, supports rich generation-parameter tuning
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -56,6 +66,8 @@ models:
    description: DeepSeek-V3.2 large language model, supports agent thinking, 131072-token extra-large context window, chat mode, supports rich generation-parameter tuning
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -66,6 +78,8 @@ models:
    description: DeepSeek-V3 large language model, supports agent thinking, 64000-token context window, chat mode, supports text and JSON output
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -76,6 +90,8 @@ models:
    description: farui-plus large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 12288-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -88,6 +104,8 @@ models:
    description: GLM-4.7 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 202752-token extra-large context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -100,6 +118,9 @@ models:
    description: qvq-max-latest large language model, supports vision, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - vision
@@ -112,6 +133,9 @@ models:
    description: qvq-max large language model, supports vision, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - vision
@@ -124,6 +148,8 @@ models:
    description: qwen-coder-turbo-0919 code-focused large language model, supports agent thinking, 131072-token context window, chat mode, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Code model
@@ -135,6 +161,8 @@ models:
    description: qwen-max-latest large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -147,6 +175,8 @@ models:
    description: qwen-max-longcontext long-context large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 32000-token context window, chat mode, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -159,6 +189,8 @@ models:
    description: qwen-max large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 32768-token context window, chat mode, supports web search
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -171,6 +203,8 @@ models:
    description: qwen-mt-plus multilingual translation large language model, supports agent thinking, 16384-token context window, chat mode, supports translation between many languages with domain adaptation
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Translation model
@@ -182,6 +216,8 @@ models:
    description: qwen-mt-turbo lightweight multilingual translation large language model, supports agent thinking, 16384-token context window, chat mode, supports translation between many languages with domain adaptation
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Translation model
@@ -193,6 +229,8 @@ models:
    description: qwen-plus-0112 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -205,6 +243,8 @@ models:
    description: qwen-plus-0125 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -217,6 +257,8 @@ models:
    description: qwen-plus-0723 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 32000-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -229,6 +271,8 @@ models:
    description: qwen-plus-0806 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -241,6 +285,8 @@ models:
    description: qwen-plus-0919 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -253,6 +299,8 @@ models:
    description: qwen-plus-1125 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -265,6 +313,8 @@ models:
    description: qwen-plus-1127 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, supports web search, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -277,6 +327,8 @@ models:
    description: qwen-plus-1220 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -289,6 +341,10 @@ models:
    description: qwen-vl-max multimodal large model, supports vision understanding, agent thinking, and video understanding, 131072-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -302,6 +358,10 @@ models:
    description: qwen-vl-plus-0809 multimodal large model, supports vision understanding, agent thinking, and video understanding, 32768-token context window, chat mode, deprecated
    is_deprecated: true
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -315,6 +375,10 @@ models:
    description: qwen-vl-plus-2025-01-02 multimodal large model, supports vision understanding, agent thinking, and video understanding, 32768-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -328,6 +392,10 @@ models:
    description: qwen-vl-plus-2025-01-25 multimodal large model, supports vision understanding, agent thinking, and video understanding, 131072-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -341,6 +409,10 @@ models:
    description: qwen-vl-plus-latest multimodal large model, supports vision understanding, agent thinking, and video understanding, 131072-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -354,6 +426,10 @@ models:
    description: qwen-vl-plus multimodal large model, supports vision understanding, agent thinking, and video understanding, 131072-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -367,6 +443,8 @@ models:
    description: qwen2.5-0.5b-instruct large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 32768-token context window, chat mode, not deprecated
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -379,6 +457,8 @@ models:
    description: qwen3-14b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -391,6 +471,8 @@ models:
    description: qwen3-235b-a22b-instruct-2507 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -403,6 +485,8 @@ models:
    description: qwen3-235b-a22b-thinking-2507 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -415,6 +499,8 @@ models:
    description: qwen3-235b-a22b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -427,6 +513,8 @@ models:
    description: qwen3-30b-a3b-instruct-2507 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -439,6 +527,8 @@ models:
    description: qwen3-30b-a3b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -451,6 +541,8 @@ models:
    description: qwen3-32b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -463,6 +555,8 @@ models:
    description: qwen3-4b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -475,6 +569,8 @@ models:
    description: qwen3-8b large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -487,6 +583,8 @@ models:
    description: qwen3-coder-30b-a3b-instruct large language model, supports agent thinking, 262144-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Code model
@@ -498,6 +596,8 @@ models:
    description: qwen3-coder-480b-a35b-instruct large language model, supports agent thinking, 262144-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Code model
@@ -509,6 +609,8 @@ models:
    description: qwen3-coder-plus-2025-09-23 large language model, supports agent thinking, 1000000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Code model
@@ -520,6 +622,8 @@ models:
    description: qwen3-coder-plus large language model, supports agent thinking, 1000000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - Code model
@@ -531,6 +635,8 @@ models:
    description: qwen3-max-2025-09-23 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 262144-token context window, chat mode, supports web search
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -544,6 +650,8 @@ models:
    description: qwen3-max-2026-01-23 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 262144-token context window, chat mode, supports web search
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -557,6 +665,8 @@ models:
    description: qwen3-max-preview large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 262144-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -569,6 +679,8 @@ models:
    description: qwen3-max large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 262144-token context window, chat mode, supports web search
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -582,6 +694,8 @@ models:
    description: qwen3-next-80b-a3b-instruct large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -594,6 +708,8 @@ models:
    description: qwen3-next-80b-a3b-thinking large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -606,6 +722,11 @@ models:
    description: qwen3-omni-flash-2025-12-01 multimodal large language model, supports vision, agent thinking, video, and audio, 65536-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
      - audio
    is_omni: true
    tags:
      - Large language model
      - Multimodal model
@@ -620,6 +741,10 @@ models:
    description: qwen3-vl-235b-a22b-instruct multimodal large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and video, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -635,6 +760,10 @@ models:
    description: qwen3-vl-235b-a22b-thinking multimodal large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and video, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -650,6 +779,10 @@ models:
    description: qwen3-vl-30b-a3b-instruct multimodal large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and video, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -665,6 +798,10 @@ models:
    description: qwen3-vl-30b-a3b-thinking multimodal large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and video, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -680,6 +817,10 @@ models:
    description: qwen3-vl-flash multimodal large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and video, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -695,6 +836,10 @@ models:
    description: qwen3-vl-plus-2025-09-23 multimodal large language model, supports vision, agent thinking, and video, 262144-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -708,6 +853,10 @@ models:
    description: qwen3-vl-plus multimodal large language model, supports vision, agent thinking, and video, 262144-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - video
    is_omni: false
    tags:
      - Large language model
      - Multimodal model
@@ -721,6 +870,8 @@ models:
    description: qwq-32b large language model, supports agent thinking and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -732,6 +883,8 @@ models:
    description: qwq-plus-0305 large language model, supports agent thinking and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -743,6 +896,8 @@ models:
    description: qwq-plus large language model, supports agent thinking and streaming tool calling, 131072-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -754,6 +909,8 @@ models:
    description: gte-rerank-v2 rerank model, 4000-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Rerank model
    logo: dashscope
@@ -763,6 +920,8 @@ models:
    description: gte-rerank rerank model, 4000-token context window
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Rerank model
    logo: dashscope
@@ -772,6 +931,9 @@ models:
    description: multimodal-embedding-v1 multimodal embedding model, supports vision, 8192-token context window, max 10 chunks
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Embedding model
      - Multimodal model
@@ -783,6 +945,8 @@ models:
    description: text-embedding-v1 text embedding model, 2048-token context window, max 25 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Embedding model
      - Text embedding
@@ -793,6 +957,8 @@ models:
    description: text-embedding-v2 text embedding model, 2048-token context window, max 25 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Embedding model
      - Text embedding
@@ -803,6 +969,8 @@ models:
    description: text-embedding-v3 text embedding model, 8192-token context window, max 10 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Embedding model
      - Text embedding
@@ -813,7 +981,9 @@ models:
    description: text-embedding-v4 text embedding model, 8192-token context window, max 10 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Embedding model
      - Text embedding
    logo: dashscope
    logo: dashscope
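Every catalog entry above gains the same two fields, capability and is_omni, so a small schema can validate each entry before it is loaded. A hedged sketch with pydantic v2; the field names mirror the YAML shown above, while the CatalogEntry class itself is illustrative, not the project's real loader model:

from pydantic import BaseModel, Field

class CatalogEntry(BaseModel):
    name: str
    type: str
    description: str = ""
    is_deprecated: bool = False
    is_official: bool = True
    capability: list[str] = Field(default_factory=list)  # e.g. ["vision", "video", "audio"]
    is_omni: bool = False
    tags: list[str] = Field(default_factory=list)
    logo: str = ""

# Usage sketch: validate one parsed YAML entry.
entry = CatalogEntry(name="qwen3-omni-flash", type="llm",
                     capability=["vision", "video", "audio"], is_omni=True)
print(entry.is_omni)  # True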
@@ -6,7 +6,7 @@ from typing import Callable
import yaml
from sqlalchemy.orm import Session

from app.models.models_model import ModelBase, ModelProvider
from app.models.models_model import ModelBase, ModelProvider, ModelConfig


def _load_yaml_config(provider: ModelProvider) -> list[dict]:
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
        print(f"\nLoading {len(models)} models for {provider.value}...")

        for model_data in models:
            config_sync_fields = {
                "logo": None,
                "capability": None,
                "is_omni": None,
                "name": None,
                "provider": None,
                "type": None,
                "description": None
            }
            try:
                # Check whether the model already exists
                existing = db.query(ModelBase).filter(
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
                    # Update the existing model's configuration
                    for key, value in model_data.items():
                        setattr(existing, key, value)

                    # Also update the ModelConfig and ModelApiKey rows bound to this model_id
                    sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
                    if sync_fields:
                        # Bulk-update ModelConfig
                        update_kwargs = {k: model_data[k] for k in sync_fields}
                        db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
                            update_kwargs,
                            synchronize_session=False
                        )

                        # Update capability and is_omni on ModelApiKey
                        if 'capability' in model_data or 'is_omni' in model_data:
                            from app.models.models_model import ModelApiKey, model_config_api_key_association
                            api_key_update = {}
                            if 'capability' in model_data:
                                api_key_update['capability'] = model_data['capability']
                            if 'is_omni' in model_data:
                                api_key_update['is_omni'] = model_data['is_omni']

                            if api_key_update:
                                # Find all associated API keys
                                api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
                                    ModelConfig,
                                    ModelConfig.id == model_config_api_key_association.c.model_config_id
                                ).filter(ModelConfig.model_id == existing.id).distinct().all()

                                if api_key_ids:
                                    api_key_ids = [aid[0] for aid in api_key_ids]
                                    db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
                                        api_key_update,
                                        synchronize_session=False
                                    )

                    db.commit()
                    if not silent:
                        print(f"Updated successfully: {model_data['name']}")
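A hedged note on the bulk updates above: query(...).update(..., synchronize_session=False) emits a single SQL UPDATE without refreshing objects already loaded in the session, so any in-memory ModelConfig or ModelApiKey instances can be stale until the next commit or refresh. The loader commits right afterwards, which expires session state. Illustrative shape only, reusing the names from the diff:

db.query(ModelConfig).filter(ModelConfig.model_id == model_id).update(
    {"is_omni": True},
    synchronize_session=False,
)
db.commit()  # expires session state; subsequent reads see the new values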
@@ -6,12 +6,19 @@ models:
    description: chatgpt-4o-latest large language model, supports multi-tool calling, agent thinking, streaming tool calling, and vision, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
      - audio
      - video
    is_omni: true
    tags:
      - Large language model
      - multi-tool-call
      - agent-thought
      - stream-tool-call
      - vision
      - audio
      - video
    logo: openai
  - name: gpt-3.5-turbo-0125
    type: llm
@@ -19,6 +26,8 @@ models:
    description: gpt-3.5-turbo-0125 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 16385-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -31,6 +40,8 @@ models:
    description: gpt-3.5-turbo-1106 large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 16385-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -43,6 +54,8 @@ models:
    description: gpt-3.5-turbo-16k large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 16385-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -55,6 +68,8 @@ models:
    description: gpt-3.5-turbo-instruct large language model, 4096-token context window, text completion mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
    logo: openai
@@ -64,6 +79,8 @@ models:
    description: gpt-3.5-turbo large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 16385-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -76,6 +93,8 @@ models:
    description: gpt-4-0125-preview large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -88,6 +107,8 @@ models:
    description: gpt-4-1106-preview large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -100,6 +121,9 @@ models:
    description: gpt-4-turbo-2024-04-09 large language model, supports multi-tool calling, agent thinking, streaming tool calling, and vision, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -113,6 +137,8 @@ models:
    description: gpt-4-turbo-preview large language model, supports multi-tool calling, agent thinking, and streaming tool calling, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -125,6 +151,9 @@ models:
    description: gpt-4-turbo large language model, supports multi-tool calling, agent thinking, streaming tool calling, and vision, 128000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -138,6 +167,8 @@ models:
    description: o1-preview large language model, supports agent thinking, 128000-token context window, chat mode, deprecated
    is_deprecated: true
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -148,6 +179,9 @@ models:
    description: o1 large language model, supports multi-tool calling, agent thinking, streaming tool calling, vision, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - multi-tool-call
@@ -162,6 +196,9 @@ models:
    description: o3-2025-04-16 large language model, supports agent thinking, tool calling, vision, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -176,6 +213,8 @@ models:
    description: o3-mini-2025-01-31 large language model, supports agent thinking, tool calling, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -189,6 +228,8 @@ models:
    description: o3-mini large language model, supports agent thinking, tool calling, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -202,6 +243,9 @@ models:
    description: o3-pro-2025-06-10 large language model, supports agent thinking, tool calling, vision, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -215,6 +259,9 @@ models:
    description: o3-pro large language model, supports agent thinking, tool calling, vision, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -228,6 +275,9 @@ models:
    description: o3 large language model, supports agent thinking, vision, tool calling, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -242,6 +292,9 @@ models:
    description: o4-mini-2025-04-16 large language model, supports agent thinking, tool calling, vision, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -256,6 +309,9 @@ models:
    description: o4-mini large language model, supports agent thinking, tool calling, vision, streaming tool calling, and structured output, 200000-token context window, chat mode
    is_deprecated: false
    is_official: true
    capability:
      - vision
    is_omni: false
    tags:
      - Large language model
      - agent-thought
@@ -270,6 +326,8 @@ models:
    description: text-embedding-3-large text embedding model, 8191-token context window, max 32 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: openai
@@ -279,6 +337,8 @@ models:
    description: text-embedding-3-small text embedding model, 8191-token context window, max 32 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: openai
@@ -288,6 +348,8 @@ models:
    description: text-embedding-ada-002 text embedding model, 8097-token context window, max 32 chunks
    is_deprecated: false
    is_official: true
    capability: []
    is_omni: false
    tags:
      - Text embedding model
    logo: openai
    logo: openai
@@ -98,7 +98,7 @@ class DifyConverter(BaseConverter):
        if not var_selector:
            return ""
        selector = var_selector.split('.')
        if len(selector) not in [2, 3]:
        if len(selector) not in [2, 3] and var_selector != "context":
            raise Exception(f"invalid variable selector: {var_selector}")
        if len(selector) == 3:
            selector = selector[1:]
@@ -332,7 +332,9 @@ class DifyConverter(BaseConverter):
            messages.append(
                MessageConfig(
                    role="user",
                    content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
                    content=self.trans_variable_format(
                        node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
                    )
                )
            )
        vision = node_data["vision"]["enabled"]

@@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
        return True

    def validate_config(self) -> bool:
        require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
        require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
        if not all(field in self.config for field in require_fields):
            return False
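The first hunk above relaxes the selector check so the bare string "context" passes validation. A hedged, standalone re-implementation of the rule the change encodes; selector shapes follow the converter code above, and the function name is illustrative:

def parse_selector(var_selector: str) -> list[str] | None:
    if not var_selector:
        return None
    if var_selector == "context":      # special-cased: not dot-structured
        return ["context"]
    parts = var_selector.split('.')
    if len(parts) not in (2, 3):       # otherwise "node.var" or "ns.node.var"
        raise ValueError(f"invalid variable selector: {var_selector}")
    return parts[1:] if len(parts) == 3 else parts

# parse_selector("context")        -> ["context"]
# parse_selector("ns.node.output") -> ["node", "output"]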
@@ -303,38 +303,52 @@ class VariablePool:
        """
        return self._get_variable_struct(selector) is not None

    def get_all_system_vars(self) -> dict[str, Any]:
    def get_all_system_vars(self, literal=False) -> dict[str, Any]:
        """Get all system variables

        Returns:
            A dict of system variables
        """
        sys_namespace = self.variables.get("sys", {})
        if literal:
            return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
        return {k: v.instance.get_value() for k, v in sys_namespace.items()}

    def get_all_conversation_vars(self) -> dict[str, Any]:
    def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
        """Get all conversation variables

        Returns:
            A dict of conversation variables
        """
        conv_namespace = self.variables.get("conv", {})
        if literal:
            return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
        return {k: v.instance.get_value() for k, v in conv_namespace.items()}

    def get_all_node_outputs(self) -> dict[str, Any]:
    def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
        """Get all node outputs (runtime variables)

        Returns:
            A dict of node outputs keyed by node ID
        """
        runtime_vars = {
            namespace: {
                k: v.instance.get_value()
                for k, v in vars_dict.items()
        if literal:
            runtime_vars = {
                namespace: {
                    k: v.instance.to_literal()
                    for k, v in vars_dict.items()
                }
                for namespace, vars_dict in self.variables.items()
                if namespace not in ("sys", "conv")
            }
        else:
            runtime_vars = {
                namespace: {
                    k: v.instance.get_value()
                    for k, v in vars_dict.items()
                }
                for namespace, vars_dict in self.variables.items()
                if namespace not in ("sys", "conv")
            }
            for namespace, vars_dict in self.variables.items()
            if namespace not in ("sys", "conv")
        }
        return runtime_vars

    def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
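A hedged sketch of the intent behind the new literal flag: get_value() keeps rich wrapper objects (for example a FileObject), while to_literal() is assumed to flatten them into template-safe primitives for Jinja-style rendering. Minimal stand-in classes for illustration only:

class Var:
    def __init__(self, value):
        self.value = value

    def get_value(self):
        return self.value                      # may be a file object, list, dict, ...

    def to_literal(self):
        v = self.value
        return v.url if hasattr(v, "url") else v   # flatten files to their URL

class FakeFile:
    def __init__(self, url):
        self.url = url

v = Var(FakeFile("s3://bucket/a.png"))
print(v.get_value())    # <FakeFile object>
print(v.to_literal())   # "s3://bucket/a.png"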
@@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db
from app.db import get_db_context
from app.models import AppRelease
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService

logger = logging.getLogger(__name__)

@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
    def _output_types(self) -> dict[str, VariableType]:
        return {"output": VariableType.STRING}

    def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
    def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]:
        """Prepare the Agent (shared logic)

        Args:
@@ -57,17 +57,17 @@ class AgentNode(BaseNode):
        if not agent_id:
            raise ValueError(f"Node {self.node_id} is missing the agent_id config")

        db = next(get_db())
        release = db.query(AppRelease).filter(
            AppRelease.id == agent_id
        ).first()
        with get_db_context() as db:
            release = db.query(AppRelease).filter(
                AppRelease.id == agent_id
            ).first()

        if not release:
            raise ValueError(f"Agent does not exist: {agent_id}")

        draft_service = DraftRunService(db)

        return draft_service, release, message
        return release, message

    async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
        """Non-streaming execution
@@ -79,19 +79,21 @@ class AgentNode(BaseNode):
        Returns:
            A state-update dict
        """
        draft_service, release, message = self._prepare_agent(variable_pool)
        release, message = self._prepare_agent(variable_pool)

        logger.info(f"Node {self.node_id} starting Agent call (non-streaming)")

        # Run the Agent (non-streaming)
        result = await draft_service.run(
            agent_config=release.config,
            model_config=None,
            message=message,
            workspace_id=variable_pool.get_value("sys.workspace_id"),
            user_id=state.get("user_id"),
            variables=variable_pool.get_all_conversation_vars()
        )
        with get_db_context() as db:
            draft_service = AgentRunService(db)

            # Run the Agent (non-streaming)
            result = await draft_service.run(
                agent_config=release.config,
                model_config=None,
                message=message,
                workspace_id=variable_pool.get_value("sys.workspace_id"),
                user_id=state.get("user_id"),
                variables=variable_pool.get_all_conversation_vars()
            )

        response = result.get("response", "")

@@ -118,34 +120,35 @@ class AgentNode(BaseNode):
        Yields:
            Streaming event dicts
        """
        draft_service, release, message = self._prepare_agent(variable_pool)
        release, message = self._prepare_agent(variable_pool)

        logger.info(f"Node {self.node_id} starting Agent call (streaming)")

        # Accumulate the full response
        full_response = ""

        with get_db_context() as db:
            draft_service = AgentRunService(db)
            # Run the Agent (streaming)
        async for chunk in draft_service.run_stream(
            agent_config=release.config,
            model_config=None,
            message=message,
            workspace_id=variable_pool.get_value("sys.workspace_id"),
            user_id=state.get("user_id"),
            variables=variable_pool.get_all_conversation_vars()
        ):
            # Extract the content
            content = chunk.get("content", "")
            full_response += content

            # Yield each chunk as it streams
            yield {
                "type": "chunk",
                "node_id": self.node_id,
                "content": content,
                "full_content": full_response,
                "meta_data": chunk.get("meta_data", {})
            }
            async for chunk in draft_service.run_stream(
                agent_config=release.config,
                model_config=None,
                message=message,
                workspace_id=variable_pool.get_value("sys.workspace_id"),
                user_id=state.get("user_id"),
                variables=variable_pool.get_all_conversation_vars()
            ):
                # Extract the content
                content = chunk.get("content", "")
                full_response += content

                # Yield each chunk as it streams
                yield {
                    "type": "chunk",
                    "node_id": self.node_id,
                    "content": content,
                    "full_content": full_response,
                    "meta_data": chunk.get("meta_data", {})
                }

        logger.info(f"Node {self.node_id} Agent call finished, output length: {len(full_response)}")
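The hunks above replace next(get_db()) with a with get_db_context() block, so the session is closed deterministically instead of being left to the abandoned generator. A hedged sketch of what such a context manager typically looks like; SessionLocal is an assumed sessionmaker instance, and whether the project's real get_db_context commits for the caller is also an assumption:

from contextlib import contextmanager

@contextmanager
def get_db_context():
    db = SessionLocal()        # assumed SQLAlchemy session factory
    try:
        yield db
        db.commit()            # assumption: commit on clean exit
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()             # always released, unlike next(get_db())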
@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.services.multimodal_service import PROVIDER_STRATEGIES
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService

logger = logging.getLogger(__name__)

@@ -548,9 +551,9 @@ class BaseNode(ABC):

        return render_template(
            template=template,
            conv_vars=variable_pool.get_all_conversation_vars(),
            node_outputs=variable_pool.get_all_node_outputs(),
            system_vars=variable_pool.get_all_system_vars(),
            conv_vars=variable_pool.get_all_conversation_vars(literal=True),
            node_outputs=variable_pool.get_all_node_outputs(literal=True),
            system_vars=variable_pool.get_all_system_vars(literal=True),
            strict=strict
        )

@@ -614,16 +617,32 @@ class BaseNode(ABC):
        return variable_pool.has(selector)

    @staticmethod
    async def process_message(provider, content, enable_file=False) -> dict | str | None:
    async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
        if isinstance(content, str):
            if enable_file:
                return {"text": content}
            return content
        elif isinstance(content, dict):
            trans_tool = PROVIDER_STRATEGIES[provider]()
            result = await trans_tool.format_image(content["url"])
            return result
        raise TypeError('Unexpected input value type')

        elif isinstance(content, FileObject):
            if content.content_cache.get(provider):
                return content.content_cache[provider]
            with get_db_read() as db:
                multimodel_service = MultimodalService(db, provider)
                message = await multimodel_service.process_files(
                    [FileInput.model_construct(
                        type=content.type,
                        url=content.url,
                        transfer_method=content.transfer_method,
                        file_type=content.origin_file_type,
                        upload_file_id=content.file_id
                    )]
                )

                if message:
                    content.content_cache[provider] = message[0]
                    return message[0]
                return None
        raise TypeError(f'Unexpected input value type - {type(content)}')

    @staticmethod
    def process_model_output(content) -> str:
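The FileObject branch above memoizes the formatted payload per provider in content_cache, so repeated calls for the same file skip the MultimodalService round trip. A hedged, stand-alone illustration of that caching shape, with a dummy formatter in place of the real service:

async def format_once(file_obj, provider: str) -> dict:
    # Return the cached payload if this provider already formatted the file.
    cached = file_obj.content_cache.get(provider)
    if cached is not None:
        return cached
    # Stand-in for the expensive provider-specific formatting step.
    formatted = {"type": "image_url", "url": file_obj.url}
    file_obj.content_cache[provider] = formatted
    return formatted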
@@ -91,8 +91,8 @@ class IterationRuntime:
        return loopstate

    def merge_conv_vars(self):
        self.variable_pool.get_all_conversation_vars().update(
            self.child_variable_pool.get_all_conversation_vars()
        self.variable_pool.variables["conv"].update(
            self.child_variable_pool.variables["conv"]
        )

    async def run_task(self, item, idx):

@@ -156,7 +156,7 @@ class LoopRuntime:

    def merge_conv_vars(self, loopstate):
        self.variable_pool.variables["conv"].update(
            self.child_variable_pool.variables.get("conv", {})
            self.child_variable_pool.variables["conv"]
        )
        loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
        loopstate["node_outputs"][self.node_id] = loop_vars
@@ -172,9 +172,9 @@ class LLMNode(BaseNode):

        if self.typed_config.vision_input and self.typed_config.vision:
            file_content = []
            files = variable_pool.get_value(self.typed_config.vision_input)
            for file in files:
                content = await self.process_message(provider, file, self.typed_config.vision)
            files = variable_pool.get_instance(self.typed_config.vision_input)
            for file in files.value:
                content = await self.process_message(provider, file.value, self.typed_config.vision)
                if content:
                    file_content.append(content)
            if messages and messages[-1]["role"] == 'user':

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any

from pydantic import BaseModel
from pydantic import BaseModel, Field

from app.schemas import FileType

@@ -45,7 +45,7 @@ class VariableType(StrEnum):
            return cls.NUMBER
        elif isinstance(var, bool):
            return cls.BOOLEAN
        elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
        elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
            return cls.FILE
        elif isinstance(var, dict):
            return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel):
    type: FileType
    url: str
    __file: bool
    transfer_method: str
    origin_file_type: str
    file_id: str | None

    content_cache: dict = Field(default_factory=dict)

    is_file: bool


class BaseVariable(ABC):

@@ -63,13 +63,16 @@ class FileVariable(BaseVariable):
    def valid_value(self, value) -> FileObject:

        if isinstance(value, dict):
            if not value.get("__file"):
            if not value.get("is_file"):
                raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
            return FileObject(
                **{
                    "type": str(value.get('type')),
                    "transfer_method": value.get("transfer_method"),
                    "url": value.get('url'),
                    "__file": True
                    "file_id": value.get("file_id"),
                    "origin_file_type": value.get("origin_file_type"),
                    "is_file": True
                }
            )
        if isinstance(value, FileObject):

@@ -35,6 +35,7 @@ from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .implicit_emotions_storage_model import ImplicitEmotionsStorage

__all__ = [
    "Tenants",
@@ -90,5 +91,6 @@ __all__ = [
    "MemoryPerceptualModel",
    "ModelBase",
    "LoadBalanceStrategy",
    "Skill"
    "Skill",
    "ImplicitEmotionsStorage"
]

45
api/app/models/implicit_emotions_storage_model.py
Normal file
@@ -0,0 +1,45 @@
"""
|
||||
Implicit Emotions Storage Model
|
||||
|
||||
数据库模型:存储用户的隐性记忆画像和情绪建议数据
|
||||
替代原有的Redis缓存方式
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ImplicitEmotionsStorage(Base):
|
||||
"""隐性记忆和情绪存储表"""
|
||||
|
||||
__tablename__ = "implicit_emotions_storage"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID")
|
||||
|
||||
# 用户标识(unique=True会自动创建唯一索引)
|
||||
end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID")
|
||||
|
||||
# 隐性记忆画像数据(JSON格式)
|
||||
implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据")
|
||||
|
||||
# 情绪建议数据(JSON格式)
|
||||
emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间")
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
|
||||
|
||||
# 数据生成时间(用于业务逻辑)
|
||||
implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间")
|
||||
emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间")
|
||||
|
||||
# 索引(只为updated_at创建索引,end_user_id的unique约束已自动创建索引)
|
||||
__table_args__ = (
|
||||
Index('idx_updated_at', 'updated_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ImplicitEmotionsStorage(id={self.id}, end_user_id={self.end_user_id})>"
|
||||
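A hedged sketch of writing a row through this model directly (SessionLocal is an assumed session factory name; the project may expose sessions differently):

from datetime import datetime

from app.db import SessionLocal  # assumption: a session factory exists under this name
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage

with SessionLocal() as db:
    row = ImplicitEmotionsStorage(
        end_user_id="user-123",
        emotion_suggestions={"health_summary": "...", "suggestions": []},
        emotion_generated_at=datetime.utcnow(),
    )
    db.add(row)
    db.commit()  # one row per end_user_id; the unique constraint rejects duplicates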
@@ -2,7 +2,7 @@ import datetime
import uuid
from enum import StrEnum

from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -78,6 +78,9 @@ class ModelConfig(BaseModel):
    description = Column(String, comment="Model description")

    # Model configuration parameters
    capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
                        comment="List of model capabilities (e.g. ['vision', 'audio', 'video'])")
    is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="Whether this is an omni model (uses a special API call)")
    config = Column(JSON, comment="Model configuration parameters")
    # - temperature: controls the randomness of generated text. Higher values make output more random and creative; lower values make it more deterministic and conservative.
    # - top_p: a sampling method that serves as an alternative to temperature, controlling how much of the highest-probability token mass the model samples from.
@@ -118,6 +121,11 @@ class ModelApiKey(BaseModel):
    api_key = Column(String, nullable=False, comment="API key")
    api_base = Column(String, comment="API base URL")

    # Model capability parameters
    capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
                        comment="List of model capabilities (e.g. ['vision', 'audio', 'video'])")
    is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="Whether this is an omni model (uses a special API call)")

    # Configuration parameters
    config = Column(JSON, comment="API-key-specific configuration")

@@ -155,6 +163,9 @@ class ModelBase(Base):
    tags = Column(ARRAY(String), default=list, nullable=False, comment="Model tags (e.g. ['chat', 'writing'])")
    add_count = Column(Integer, default=0, nullable=False, comment="Number of times the model has been added by users")
    created_at = Column(DateTime, default=datetime.datetime.now, comment="Creation time", server_default=func.now())
    capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
                        comment="List of model capabilities (e.g. ['vision', 'audio', 'video'])")
    is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="Whether this is an omni model (uses a special API call)")

    # Relationships
    configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")

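Because the Python-side default only applies to new ORM objects, the paired server_default=text("'{}'::varchar[]") is what backfills existing rows at the SQL level. A hedged sketch of the matching migration step (the Alembic setup and the 'model_configs' table name are assumptions, not part of this commit):

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql


def upgrade():
    # Server defaults let the ALTER TABLE populate pre-existing rows.
    op.add_column('model_configs', sa.Column(
        'capability', postgresql.ARRAY(sa.String()),
        nullable=False, server_default=sa.text("'{}'::varchar[]")))
    op.add_column('model_configs', sa.Column(
        'is_omni', sa.Boolean(), nullable=False, server_default='false'))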
169
api/app/repositories/implicit_emotions_storage_repository.py
Normal file
@@ -0,0 +1,169 @@
"""
|
||||
Implicit Emotions Storage Repository
|
||||
|
||||
数据访问层:处理隐性记忆和情绪数据的数据库操作
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from typing import Optional, Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, not_, exists
|
||||
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitEmotionsStorageRepository:
|
||||
"""隐性记忆和情绪存储仓储类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]:
|
||||
"""根据终端用户ID获取存储记录"""
|
||||
try:
|
||||
stmt = select(ImplicitEmotionsStorage).where(
|
||||
ImplicitEmotionsStorage.end_user_id == end_user_id
|
||||
)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}")
|
||||
return None
|
||||
|
||||
def create(self, end_user_id: str) -> ImplicitEmotionsStorage:
|
||||
"""创建新的存储记录(事务由调用方提交)"""
|
||||
storage = ImplicitEmotionsStorage(
|
||||
end_user_id=end_user_id,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
self.db.add(storage)
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_implicit_profile(
|
||||
self,
|
||||
end_user_id: str,
|
||||
profile_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新隐性记忆画像数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.implicit_profile = profile_data
|
||||
storage.implicit_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
suggestions_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新情绪建议数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.emotion_suggestions = suggestions_data
|
||||
storage.emotion_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新情绪建议成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取所有已存储数据的用户ID(避免大数据量内存溢出)
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(ImplicitEmotionsStorage.end_user_id)
|
||||
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from batch
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||
|
||||
查询逻辑:end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。
|
||||
没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。
|
||||
end_users.id(UUID)转为字符串后与 implicit_emotions_storage.end_user_id(String)对比。
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
from sqlalchemy import cast, String as SAString
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
tomorrow_start = today_start + timedelta(days=1)
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(EndUser.id)
|
||||
.where(
|
||||
EndUser.created_at >= today_start,
|
||||
EndUser.created_at < tomorrow_start,
|
||||
not_(
|
||||
exists(
|
||||
select(ImplicitEmotionsStorage.end_user_id).where(
|
||||
ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.order_by(EndUser.id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from (str(uid) for uid in batch)
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def delete_by_end_user_id(self, end_user_id: str) -> bool:
|
||||
"""删除用户的存储记录(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage:
|
||||
self.db.delete(storage)
|
||||
self.db.flush()
|
||||
logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}")
|
||||
return True
|
||||
return False
|
||||
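A usage sketch for the repository above, keeping the commit on the caller's side as the module docstring requires (that get_db_context yields a Session is an assumption based on its use elsewhere in this diff):

from app.db import get_db_context
from app.repositories.implicit_emotions_storage_repository import (
    ImplicitEmotionsStorageRepository,
)

with get_db_context() as db:
    repo = ImplicitEmotionsStorageRepository(db)
    # flush/refresh only happens inside the repository; nothing is durable yet
    repo.update_implicit_profile("user-123", {"interests": ["hiking"]})
    for uid in repo.get_new_user_ids_today(batch_size=50):
        print("needs first-time initialization:", uid)
    db.commit()  # the caller owns the transaction boundary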
@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
    except Exception as e:
        db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
        raise


def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """
    Sum chunk_num over rows in the knowledges table with permission_id='Memory' (user knowledge bases) for the given workspace_id.
    """
    db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")

    try:
        from sqlalchemy import func
        result = db.query(func.sum(Knowledge.chunk_num)).filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1,
            Knowledge.permission_id == "Memory"
        ).scalar()

        total = result if result is not None else 0
        db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
        return total
    except Exception as e:
        db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
        raise


def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """
    Count rows in the knowledges table for the given workspace_id, excluding user knowledge bases (permission_id != 'Memory').
    """
    db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")

    try:
        count = db.query(Knowledge).filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1,
            Knowledge.permission_id != "Memory"
        ).count()

        db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
        return count
    except Exception as e:
        db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
        raise


@@ -374,7 +374,7 @@ class OntologySceneRepository:

        count = self.db.query(OntologyScene).filter(
            OntologyScene.scene_id == scene_id,
            OntologyScene.workspace_id == workspace_id
            (OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True)
        ).count()

        is_owner = count > 0

@@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel):
    type: ApiKeyType = Field(..., description="API key type")
    scopes: List[str] = Field(default_factory=list, description="List of permission scopes")
    resource_id: Optional[uuid.UUID] = Field(None, description="Associated resource ID")
    rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS limit (requests/second)")
    rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS limit (requests/second)")
    daily_request_limit: Optional[int] = Field(10000, description="Daily request limit", ge=1)
    quota_limit: Optional[int] = Field(None, description="Quota limit (total requests)", ge=1)
    expires_at: Optional[datetime.datetime] = Field(None, description="Expiration time")
@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
        return datetime.datetime.now() > self.expires_at

    @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
    @classmethod
    def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
    def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
        """Convert datetime to a timestamp."""
        return datetime_to_timestamp(v)

@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
    avg_response_time: Optional[float] = Field(None, description="Average response time (ms)")

    @field_serializer('last_used_at')
    @classmethod
    def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
    def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
        """Convert datetime to a timestamp."""
        return datetime_to_timestamp(v)

@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
    created_at: datetime.datetime

    @field_serializer('created_at')
    @classmethod
    def serialize_datetime(cls, v: datetime.datetime) -> int:
    def serialize_datetime(self, v: datetime.datetime) -> int:
        """Convert datetime to a timestamp."""
        return datetime_to_timestamp(v)

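Dropping @classmethod here matters: in Pydantic v2, @field_serializer expects an instance method, so the stacked @classmethod variant does not serialize as intended. A self-contained check of the corrected form (the millisecond unit mirrors what datetime_to_timestamp presumably returns; that unit is an assumption):

import datetime
from pydantic import BaseModel, field_serializer


class Stamped(BaseModel):
    created_at: datetime.datetime

    @field_serializer('created_at')
    def serialize_datetime(self, v: datetime.datetime) -> int:
        # assumed to match datetime_to_timestamp: epoch milliseconds
        return int(v.timestamp() * 1000)


print(Stamped(created_at=datetime.datetime(2024, 1, 1)).model_dump())
# e.g. {'created_at': 1704067200000} (exact value depends on local timezone)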
@@ -21,8 +21,14 @@ class FileType(StrEnum):
    def trans(cls, value: str) -> 'FileType':
        if value.startswith("image"):
            return cls.IMAGE
        # TODO: other file type support
        raise RuntimeError("Unsupport file type")
        elif value.startswith("document"):
            return cls.DOCUMENT
        elif value.startswith("audio"):
            return cls.AUDIO
        elif value.startswith("video"):
            return cls.VIDEO
        else:
            raise RuntimeError("Unsupport file type")


class TransferMethod(str, Enum):
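A usage sketch for the extended trans() dispatch (illustrative, not part of the diff; the member names match the branches above):

from app.schemas import FileType

assert FileType.trans("image/jpg") is FileType.IMAGE
assert FileType.trans("document/docx") is FileType.DOCUMENT
assert FileType.trans("audio/wav") is FileType.AUDIO
assert FileType.trans("video/mp4") is FileType.VIDEO
try:
    FileType.trans("application/zip")  # unrecognized prefixes still raise
except RuntimeError:
    pass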
@@ -37,6 +43,12 @@ class FileInput(BaseModel):
    transfer_method: TransferMethod = Field(..., description="Transfer method: local_file/remote_url")
    upload_file_id: Optional[uuid.UUID] = Field(None, description="Uploaded file ID (required for local_file)")
    url: Optional[str] = Field(None, description="Remote URL (required for remote_url)")
    file_type: Optional[str] = Field(None, description="Concrete file format (e.g. image/jpg, audio/wav, document/docx, video/mp4)")

    def __init__(self, **data):
        if "type" in data:
            data['file_type'] = data['type']
        super().__init__(**data)

    @field_validator("type", mode="before")
    @classmethod

@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
class ChunkRetrieve(BaseModel):
    query: str
    kb_ids: list[uuid.UUID]
    file_names_filter: list[str] | None = Field(None)
    similarity_threshold: float | None = Field(None)
    vector_similarity_weight: float | None = Field(None)
    top_k: int | None = Field(None)

@@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel):
    is_active: bool = Field(True, description="Whether active")
    is_public: bool = Field(False, description="Whether public")
    load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="Load-balancing strategy")
    capability: List[str] = Field(default_factory=list, description="List of model capabilities")
    is_omni: bool = Field(False, description="Whether this is an omni model")


class ApiKeyCreateNested(BaseModel):
@@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel):
    provider: Optional[str] = Field(None, description="API key provider")
    api_key: str = Field(..., description="API key", max_length=500)
    api_base: Optional[str] = Field(None, description="API base URL", max_length=500)
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")
    config: Optional[Dict[str, Any]] = Field({}, description="API-key-specific configuration")
    priority: str = Field("1", description="Priority", max_length=10)
@@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel):
    config: Optional[Dict[str, Any]] = Field(None, description="Model configuration parameters")
    is_active: Optional[bool] = Field(None, description="Whether active")
    is_public: Optional[bool] = Field(None, description="Whether public")
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")


class ModelConfig(ModelConfigBase):
@@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel):
    api_key: str = Field(..., description="API key", max_length=500)
    api_base: Optional[str] = Field(None, description="API base URL", max_length=500)
    description: Optional[str] = Field(None, description="Notes")
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")
    config: Optional[Dict[str, Any]] = Field({}, description="API-key-specific configuration")
    is_active: bool = Field(True, description="Whether active")
    priority: str = Field("1", description="Priority", max_length=10)
@@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel):
    provider: ModelProvider = Field(..., description="API key provider")
    api_key: str = Field(..., description="API key", max_length=500)
    api_base: Optional[str] = Field(None, description="API base URL", max_length=500)
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")
    config: Optional[Dict[str, Any]] = Field({}, description="API-key-specific configuration")
    is_active: bool = Field(True, description="Whether active")
    priority: str = Field("1", description="Priority", max_length=10)
@@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel):
    provider: Optional[ModelProvider] = Field(None, description="API key provider")
    api_key: Optional[str] = Field(None, description="API key", max_length=500)
    api_base: Optional[str] = Field(None, description="API base URL", max_length=500)
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")
    config: Optional[Dict[str, Any]] = Field(None, description="API-key-specific configuration")
    is_active: Optional[bool] = Field(None, description="Whether active")
    priority: Optional[str] = Field(None, description="Priority", max_length=10)
@@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel):
    description: Optional[str] = Field(None, description="Model description")
    is_official: bool = Field(True, description="Whether an official provider model")
    tags: List[str] = Field(default_factory=list, description="Model tags")
    capability: List[str] = Field(default_factory=list, description="List of model capabilities (e.g. ['vision', 'audio', 'video'])")
    is_omni: bool = Field(False, description="Whether this is an omni model")


class ModelBaseUpdate(BaseModel):
@@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel):
    is_deprecated: Optional[bool] = Field(None, description="Whether deprecated")
    is_official: Optional[bool] = Field(None, description="Whether an official provider model")
    tags: Optional[List[str]] = Field(None, description="Model tags")
    capability: Optional[List[str]] = Field(None, description="List of model capabilities")
    is_omni: Optional[bool] = Field(None, description="Whether this is an omni model")


class ModelBase(BaseModel):
@@ -298,6 +314,8 @@ class ModelBase(BaseModel):
    is_official: bool
    tags: List[str]
    add_count: int
    capability: List[str] = []
    is_omni: bool = False


class ModelBaseQuery(BaseModel):

@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
class MultiAgentConfigCreate(BaseModel):
    """Create a multi-agent configuration."""
    master_agent_id: uuid.UUID = Field(..., description="Master agent ID")
    master_agent_name: Optional[str] = Field(None, max_length=100, description="Master agent name")
    master_agent_name: Optional[str] = Field(default=None, max_length=100, description="Master agent name")
    orchestration_mode: str = Field(
        default="collaboration",
        pattern="^(collaboration|supervisor)$",
        description="Orchestration mode: collaboration | supervisor"
    )
    sub_agents: List[SubAgentConfig] = Field(..., description="List of sub-agents")
    routing_rules: Optional[List[RoutingRule]] = Field(None, description="Routing rules")
    routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="Routing rules")
    execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="Execution configuration")
    aggregation_strategy: str = Field(
        default="merge",
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
class MultiAgentConfigUpdate(BaseModel):
    """Update a multi-agent configuration."""
    master_agent_id: Optional[uuid.UUID] = None
    master_agent_name: Optional[str] = Field(None, max_length=100, description="Master agent name")
    master_agent_name: Optional[str] = Field(default=None, max_length=100, description="Master agent name")
    default_model_config_id: Optional[uuid.UUID] = Field(None, description="Default model configuration ID")
    model_parameters: Optional[ModelParameters] = Field(
        None,

@@ -241,6 +241,7 @@ class SceneResponse(BaseModel):
    created_at: datetime.datetime = Field(..., description="Creation time (millisecond timestamp)")
    updated_at: datetime.datetime = Field(..., description="Update time (millisecond timestamp)")
    classes_count: int = Field(0, description="Number of classes")
    is_system_default: bool = Field(False, description="Whether this is the system default scene")

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, dt: datetime.datetime):
@@ -462,6 +463,7 @@ class ClassListResponse(BaseModel):
    scene_id: UUID = Field(..., description="Owning scene ID")
    scene_name: str = Field(..., description="Scene name")
    scene_description: Optional[str] = Field(None, description="Scene description")
    is_system_default: bool = Field(False, description="Whether this is the system default scene")
    items: List[ClassResponse] = Field(..., description="List of classes")

@@ -263,8 +263,8 @@ def create_agent_invocation_tool(

    try:
        # 9. Invoke the agent
        from app.services.draft_run_service import DraftRunService
        draft_service = DraftRunService(db)
        from app.services.draft_run_service import AgentRunService
        draft_service = AgentRunService(db)

        result = await draft_service.run(
            agent_config=agent_config,

@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session

from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
    AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.workflow_service import WorkflowService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService

logger = get_business_logger()

@@ -39,6 +38,8 @@ class AppChatService:

    def __init__(self, db: Session):
        self.db = db
        self.conversation_service = ConversationService(db)
        self.agent_service = AgentRunService(db)
        self.workflow_service = WorkflowService(db)

    async def agnet_chat(
        self,
@@ -55,12 +56,10 @@ class AppChatService:
        files: Optional[List[FileInput]] = None  # new: multimodal files
    ) -> Dict[str, Any]:
        """Chat (non-streaming)."""

        start_time = time.time()
        config_id = None

        if variables is None:
            variables = {}
        variables = self.agent_service.prepare_variables(variables, config.variables)

        # Fetch the model configuration ID
        model_config_id = config.default_model_config_id
@@ -79,74 +78,20 @@ class AppChatService:
        tools = []

        # Get the tool service
        tool_service = ToolService(self.db)
        tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))

        # Collect enabled tools from the config
        if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
            for tool_config in config.tools:
                if tool_config.get("enabled", False):
                    # Look up the tool instance by tool ID
                    tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
                    if tool_instance:
                        if tool_instance.name == "baidu_search_tool" and not web_search:
                            continue
                        # Convert to a LangChain tool
                        langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                        tools.append(langchain_tool)
        elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
            web_tools = config.tools
            web_search_choice = web_tools.get("web_search", {})
            web_search_enable = web_search_choice.get("enabled", False)
            if web_search:
                if web_search_enable:
                    search_tool = create_web_search_tool({})
                    tools.append(search_tool)

                    logger.debug(
                        "Added web search tool",
                        extra={
                            "tool_count": len(tools)
                        }
                    )

        # Load tools associated with skills
        if hasattr(config, 'skills') and config.skills:
            skills = config.skills
            skill_enable = skills.get("enabled", False)
            if skill_enable:
                middleware = AgentMiddleware(skills=skills)
                skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
                tools.extend(skill_tools)
                logger.debug(f"Loaded {len(skill_tools)} skill tools")

                # Apply dynamic filtering
                if skill_configs:
                    tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
                                                                         tool_to_skill_map)
                    logger.debug(f"{len(tools)} tools remain after filtering")
                    active_prompts = AgentMiddleware.get_active_prompts(
                        activated_skill_ids, skill_configs
                    )
                    system_prompt = f"{system_prompt}\n\n{active_prompts}"

        # Add the knowledge base retrieval tool
        knowledge_retrieval = config.knowledge_retrieval
        if knowledge_retrieval:
            knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
            kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
            if kb_ids:
                kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
                tools.append(kb_tool)

        # Add the long-term memory tool
        tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
        skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
        tools.extend(skill_tools)
        if skill_prompts:
            system_prompt = f"{system_prompt}\n\n{skill_prompts}"
        tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
        memory_flag = False
        if memory == True:
            memory_config = config.memory
            if memory_config.get("enabled") and user_id:
                memory_flag = True
                memory_tool = create_long_term_memory_tool(memory_config, user_id)
                tools.append(memory_tool)
        if memory:
            memory_tools, memory_flag = self.agent_service.load_memory_config(
                config.memory, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

        # Get model parameters
        model_parameters = config.model_parameters
@@ -157,6 +102,7 @@ class AppChatService:
            api_key=api_key_obj.api_key,
            provider=api_key_obj.provider,
            api_base=api_key_obj.api_base,
            is_omni=api_key_obj.is_omni,
            temperature=model_parameters.get("temperature", 0.7),
            max_tokens=model_parameters.get("max_tokens", 2000),
            system_prompt=system_prompt,
@@ -180,7 +126,7 @@ class AppChatService:
        # Process multimodal files
        processed_files = None
        if files:
            multimodal_service = MultimodalService(self.db)
            multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
            processed_files = await multimodal_service.process_files(files)
            logger.info(f"Processed {len(processed_files)} files")

@@ -245,10 +191,9 @@ class AppChatService:
        try:
            start_time = time.time()
            config_id = None
            yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"

            if variables is None:
                variables = {}

            variables = self.agent_service.prepare_variables(variables, config.variables)
            # Fetch the model configuration ID
            model_config_id = config.default_model_config_id
            api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
@@ -266,73 +211,22 @@ class AppChatService:
            tools = []

            # Get the tool service
            tool_service = ToolService(self.db)
            tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))

            if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
                for tool_config in config.tools:
                    if tool_config.get("enabled", False):
                        # Look up the tool instance by tool ID
                        tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
                        if tool_instance:
                            if tool_instance.name == "baidu_search_tool" and not web_search:
                                continue
                            # Convert to a LangChain tool
                            langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
                            tools.append(langchain_tool)
            elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
                web_tools = config.tools
                web_search_choice = web_tools.get("web_search", {})
                web_search_enable = web_search_choice.get("enabled", False)
                if web_search:
                    if web_search_enable:
                        search_tool = create_web_search_tool({})
                        tools.append(search_tool)

                        logger.debug(
                            "Added web search tool",
                            extra={
                                "tool_count": len(tools)
                            }
                        )

            # Load tools associated with skills
            if hasattr(config, 'skills') and config.skills:
                skills = config.skills
                skill_enable = skills.get("enabled", False)
                if skill_enable:
                    middleware = AgentMiddleware(skills=skills)
                    skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
                    tools.extend(skill_tools)
                    logger.debug(f"Loaded {len(skill_tools)} skill tools")

                    # Apply dynamic filtering
                    if skill_configs:
                        tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
                                                                             tool_to_skill_map)
                        logger.debug(f"{len(tools)} tools remain after filtering")
                        active_prompts = AgentMiddleware.get_active_prompts(
                            activated_skill_ids, skill_configs
                        )
                        system_prompt = f"{system_prompt}\n\n{active_prompts}"

            # Add the knowledge base retrieval tool
            knowledge_retrieval = config.knowledge_retrieval
            if knowledge_retrieval:
                knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
                kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
                if kb_ids:
                    kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
                    tools.append(kb_tool)
            tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))

            skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
            tools.extend(skill_tools)
            if skill_prompts:
                system_prompt = f"{system_prompt}\n\n{skill_prompts}"
            tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
            # Add the long-term memory tool
            memory_flag = False
            if memory:
                memory_config = config.memory
                if memory_config.get("enabled") and user_id:
                    memory_flag = True
                    memory_tool = create_long_term_memory_tool(memory_config, user_id)
                    tools.append(memory_tool)
            memory_tools, memory_flag = self.agent_service.load_memory_config(
                config.memory, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

            # Get model parameters
            model_parameters = config.model_parameters
@@ -343,6 +237,7 @@ class AppChatService:
            api_key=api_key_obj.api_key,
            provider=api_key_obj.provider,
            api_base=api_key_obj.api_base,
            is_omni=api_key_obj.is_omni,
            temperature=model_parameters.get("temperature", 0.7),
            max_tokens=model_parameters.get("max_tokens", 2000),
            system_prompt=system_prompt,
@@ -366,13 +261,10 @@ class AppChatService:
            # Process multimodal files
            processed_files = None
            if files:
                multimodal_service = MultimodalService(self.db)
                multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
                processed_files = await multimodal_service.process_files(files)
                logger.info(f"Processed {len(processed_files)} files")

            # Send the start event
            yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"

            # Stream the agent call (multimodal supported)
            full_content = ""
            total_tokens = 0
@@ -416,7 +308,7 @@ class AppChatService:
            ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)

            # Send the end event
            end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
            end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
            yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"

            logger.info(
@@ -435,7 +327,7 @@ class AppChatService:
        except Exception as e:
            logger.error(f"Streaming chat failed: {str(e)}", exc_info=True)
            # Send the error event
            yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
            yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

    async def multi_agent_chat(
        self,
@@ -489,10 +381,10 @@ class AppChatService:
                "mode": result.get("mode"),
                "elapsed_time": result.get("elapsed_time"),
                "usage": result.get("usage", {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": 0
                    })
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0
                })
            }
        )

@@ -522,8 +414,6 @@ class AppChatService:
        """Multi-agent chat (streaming)."""

        start_time = time.time()
        actual_config_id = None
        config_id = actual_config_id

        if variables is None:
            variables = {}
@@ -629,7 +519,6 @@ class AppChatService:
        user_rag_memory_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Chat (non-streaming)."""
        workflow_service = WorkflowService(self.db)
        payload = DraftRunRequest(
            message=message,
            variables=variables,
@@ -637,7 +526,7 @@ class AppChatService:
            stream=True,
            user_id=user_id
        )
        return await workflow_service.run(
        return await self.workflow_service.run(
            app_id=app_id,
            payload=payload,
            config=config,
@@ -664,7 +553,6 @@ class AppChatService:

    ) -> AsyncGenerator[dict, None]:
        """Chat (streaming)."""
        workflow_service = WorkflowService(self.db)
        payload = DraftRunRequest(
            message=message,
            variables=variables,
@@ -673,7 +561,7 @@ class AppChatService:
            user_id=user_id,
            files=files
        )
        async for event in workflow_service.run_stream(
        async for event in self.workflow_service.run_stream(
            app_id=app_id,
            payload=payload,
            config=config,

@@ -232,7 +232,7 @@ class AppService:
        # Check the master agent's model configuration
        multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id

        model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
        model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
        if not model_api_key:
            raise ResourceNotFoundException("Model configuration", str(multi_agent_config.default_model_config_id))

@@ -1791,372 +1791,6 @@ class AppService:

        return shares

    # ==================== Draft-run features ====================

    async def draft_run(
        self,
        *,
        app_id: uuid.UUID,
        message: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        workspace_id: Optional[uuid.UUID] = None
    ) -> Dict[str, Any]:
        """Draft-run an agent (using the current draft configuration).

        Args:
            app_id: application ID
            message: user message
            conversation_id: conversation ID (for multi-turn chat)
            user_id: user ID (for conversation management)
            variables: custom variable values
            workspace_id: workspace ID (for permission checks)

        Returns:
            Dict: dictionary containing the AI reply and metadata

        Raises:
            ResourceNotFoundException: if the app does not exist
            BusinessException: if the app type is unsupported or configuration is missing
        """
        from app.services.draft_run_service import DraftRunService

        logger.info("Draft-running agent", extra={"app_id": str(app_id), "user_message": message[:50]})

        # 1. Validate the app
        app = self._get_app_or_404(app_id)

        if app.type != "agent":
            raise BusinessException("Only agent-type apps support draft runs", BizCode.APP_TYPE_NOT_SUPPORTED)

        # Read-only operation; shared apps are accessible
        self._validate_app_accessible(app, workspace_id)

        # 2. Fetch the agent configuration
        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()

        if not agent_cfg:
            raise BusinessException("Agent configuration missing; cannot draft-run", BizCode.AGENT_CONFIG_MISSING)

        # 3. Fetch the model configuration
        model_config = None
        if agent_cfg.default_model_config_id:
            from app.models import ModelConfig
            model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

        if not model_config:
            raise BusinessException("Model configuration missing; cannot draft-run", BizCode.AGENT_CONFIG_MISSING)

        # 4. Call the draft-run service
        logger.debug(
            "Preparing to call the draft-run service",
            extra={
                "app_id": str(app_id),
                "model": model_config.name,
                "has_conversation_id": bool(conversation_id),
                "has_variables": bool(variables)
            }
        )

        draft_service = DraftRunService(self.db)
        result = await draft_service.run(
            agent_config=agent_cfg,
            model_config=model_config,
            message=message,
            workspace_id=workspace_id,
            conversation_id=conversation_id,
            user_id=user_id,
            variables=variables
        )

        logger.debug(
            "Draft-run service returned",
            extra={
                "result_type": str(type(result)),
                "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
                "has_message": "message" in result if isinstance(result, dict) else False,
                "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
            }
        )

        logger.info(
            "Draft run complete",
            extra={
                "app_id": str(app_id),
                "elapsed_time": result.get("elapsed_time"),
                "model": model_config.name
            }
        )

        return result

    async def draft_run_stream(
        self,
        *,
        app_id: uuid.UUID,
        message: str,
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        workspace_id: Optional[uuid.UUID] = None
    ):
        """Draft-run an agent (streaming output).

        Args:
            app_id: application ID
            message: user message
            conversation_id: conversation ID (for multi-turn chat)
            user_id: user ID (for conversation management)
            variables: custom variable values
            workspace_id: workspace ID (for permission checks)

        Yields:
            str: SSE-formatted event data

        Raises:
            ResourceNotFoundException: if the app does not exist
            BusinessException: if the app type is unsupported or configuration is missing
        """
        from app.services.draft_run_service import DraftRunService

        logger.info("Streaming draft run of agent", extra={"app_id": str(app_id), "user_message": message[:50]})

        # 1. Validate the app
        app = self._get_app_or_404(app_id)

        if app.type != "agent":
            raise BusinessException("Only agent-type apps support draft runs", BizCode.APP_TYPE_NOT_SUPPORTED)

        # Read-only operation; shared apps are accessible
        self._validate_app_accessible(app, workspace_id)

        # 2. Fetch the agent configuration
        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()

        if not agent_cfg:
            raise BusinessException("Agent configuration missing; cannot draft-run", BizCode.AGENT_CONFIG_MISSING)

        # 3. Fetch the model configuration
        model_config = None
        if agent_cfg.default_model_config_id:
            from app.models import ModelConfig
            model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)

        if not model_config:
            raise BusinessException("Model configuration missing; cannot draft-run", BizCode.AGENT_CONFIG_MISSING)

        # 4. Call the streaming draft-run service
        draft_service = DraftRunService(self.db)
        async for event in draft_service.run_stream(
            agent_config=agent_cfg,
            model_config=model_config,
            message=message,
            workspace_id=workspace_id,
            conversation_id=conversation_id,
            user_id=user_id,
            variables=variables
        ):
            yield event

    # ==================== Multi-model comparison draft runs ====================

    async def draft_run_compare(
        self,
        *,
        app_id: uuid.UUID,
        message: str,
        models: List[app_schema.ModelCompareItem],
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        workspace_id: Optional[uuid.UUID] = None,
        parallel: bool = True,
        timeout: int = 60
    ) -> Dict[str, Any]:
        """Multi-model comparison draft run.

        Args:
            app_id: application ID
            message: user message
            models: list of models to compare
            conversation_id: conversation ID
            user_id: user ID
            variables: variable values
            workspace_id: workspace ID
            parallel: whether to run in parallel
            timeout: timeout in seconds

        Returns:
            Dict: comparison results
        """
        from app.models import ModelConfig
        from app.services.draft_run_service import DraftRunService

        logger.info(
            "Multi-model comparison draft run",
            extra={
                "app_id": str(app_id),
                "model_count": len(models),
                "parallel": parallel
            }
        )

        # 1. Validate the app
        app = self._get_app_or_404(app_id)
        if app.type != "agent":
            raise BusinessException("Only agent-type apps support draft runs", BizCode.APP_TYPE_NOT_SUPPORTED)

        # Read-only operation; shared apps are accessible
        self._validate_app_accessible(app, workspace_id)

        # 2. Fetch the agent configuration
        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()
        if not agent_cfg:
            raise BusinessException("Agent configuration missing", BizCode.AGENT_CONFIG_MISSING)

        # 3. Prepare all model configurations
        model_configs = []
        for model_item in models:
            model_config = self.db.get(ModelConfig, model_item.model_config_id)
            if not model_config:
                raise ResourceNotFoundException("Model configuration", str(model_item.model_config_id))

            # Merge parameters: agent config parameters + per-request overrides
            merged_parameters = {
                **(agent_cfg.model_parameters or {}),
                **(model_item.model_parameters or {})
            }

            model_configs.append({
                "model_config": model_config,
                "parameters": merged_parameters,
                "label": model_item.label or model_config.name,
                "model_config_id": model_item.model_config_id
            })

        # 4. Call DraftRunService's comparison method
        draft_service = DraftRunService(self.db)
        result = await draft_service.run_compare(
            agent_config=agent_cfg,
            models=model_configs,
            message=message,
            workspace_id=workspace_id,
            conversation_id=conversation_id,
            user_id=user_id,
            variables=variables,
            parallel=parallel,
            timeout=timeout
        )

        logger.info(
            "Multi-model comparison complete",
            extra={
                "app_id": str(app_id),
                "successful": result["successful_count"],
                "failed": result["failed_count"]
            }
        )

        return result

    async def draft_run_compare_stream(
        self,
        *,
        app_id: uuid.UUID,
        message: str,
        models: List[app_schema.ModelCompareItem],
        conversation_id: Optional[str] = None,
        user_id: Optional[str] = None,
        variables: Optional[Dict[str, Any]] = None,
        workspace_id: Optional[uuid.UUID] = None,
        parallel: bool = True,
        timeout: int = 60
    ):
        """Multi-model comparison draft run (streaming output).

        Args:
            app_id: application ID
            message: user message
            models: list of models to compare
            conversation_id: conversation ID
            user_id: user ID
            variables: variable values
            workspace_id: workspace ID
            timeout: timeout in seconds

        Yields:
            str: SSE-formatted event data
        """
        from app.models import ModelConfig
        from app.services.draft_run_service import DraftRunService

        logger.info(
            "Multi-model comparison streaming draft run",
            extra={
                "app_id": str(app_id),
                "model_count": len(models)
            }
        )

        # 1. Validate the app
        app = self._get_app_or_404(app_id)
        if app.type != "agent":
            raise BusinessException("Only agent-type apps support draft runs", BizCode.APP_TYPE_NOT_SUPPORTED)

        # Read-only operation; shared apps are accessible
        self._validate_app_accessible(app, workspace_id)

        # 2. Fetch the agent configuration
        stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
        agent_cfg = self.db.scalars(stmt).first()
        if not agent_cfg:
            raise BusinessException("Agent configuration missing", BizCode.AGENT_CONFIG_MISSING)

        # 3. Prepare all model configurations
        model_configs = []
        for model_item in models:
            model_config = self.db.get(ModelConfig, model_item.model_config_id)
            if not model_config:
                raise ResourceNotFoundException("Model configuration", str(model_item.model_config_id))

            # Merge parameters: agent config parameters + per-request overrides
            merged_parameters = {
                **(agent_cfg.model_parameters or {}),
                **(model_item.model_parameters or {})
            }

            model_configs.append({
                "model_config": model_config,
                "parameters": merged_parameters,
                "label": model_item.label or model_config.name,
                "model_config_id": model_item.model_config_id
            })

        # 4. Call DraftRunService's streaming comparison method
        draft_service = DraftRunService(self.db)
        async for event in draft_service.run_compare_stream(
            agent_config=agent_cfg,
            models=model_configs,
            message=message,
            workspace_id=workspace_id,
            conversation_id=conversation_id,
            user_id=user_id,
            variables=variables,
            parallel=parallel,
            timeout=timeout
        ):
            yield event

        logger.info(
            "Multi-model comparison streaming complete",
            extra={"app_id": str(app_id)}
        )


# ==================== Backward-compatible function interfaces ====================
# Function interfaces are kept for compatibility with existing code, but they delegate to the service class internally

@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
    return service.get_apps_by_ids(app_ids, workspace_id)


# ==================== Backward-compatible function interfaces ====================

async def draft_run(
    db: Session,
    *,
    app_id: uuid.UUID,
    message: str,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """Draft-run an agent (backward-compatible interface)."""
    service = AppService(db)
    return await service.draft_run(
        app_id=app_id,
        message=message,
        conversation_id=conversation_id,
        user_id=user_id,
        variables=variables,
        workspace_id=workspace_id
    )


async def draft_run_stream(
    db: Session,
    *,
    app_id: uuid.UUID,
    message: str,
    conversation_id: Optional[str] = None,
    user_id: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
    workspace_id: Optional[uuid.UUID] = None
):
    """Draft-run an agent with streaming output (backward-compatible interface)."""
    service = AppService(db)
    async for event in service.draft_run_stream(
        app_id=app_id,
        message=message,
        conversation_id=conversation_id,
        user_id=user_id,
        variables=variables,
        workspace_id=workspace_id
    ):
        yield event


# ==================== Dependency injection helpers ====================

def get_app_service(

101
api/app/services/audio_transcription_service.py
Normal file
@@ -0,0 +1,101 @@
"""
|
||||
音频转文本服务
|
||||
|
||||
支持的服务商:
|
||||
- DashScope (阿里云通义千问)
|
||||
- OpenAI Whisper
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AudioTranscriptionService:
|
||||
"""音频转文本服务"""
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用阿里云通义千问语音识别服务转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: DashScope API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
},
|
||||
json={
|
||||
"model": "paraformer-v2",
|
||||
"input": {
|
||||
"file_urls": [audio_url]
|
||||
},
|
||||
"parameters": {
|
||||
"language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
|
||||
}
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("output", {}).get("results"):
|
||||
text = result["output"]["results"][0].get("transcription_text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
return "[音频转文本失败]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_openai(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用 OpenAI Whisper 转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: OpenAI API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
# 调用 Whisper API
|
||||
files = {"file": ("audio.mp3", audio_data, "audio/mpeg")}
|
||||
data = {"model": "whisper-1"}
|
||||
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/audio/transcriptions",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
files=files,
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
text = result.get("text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
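A hedged usage sketch (API keys are placeholders; note that the DashScope call submits an asynchronous job via X-DashScope-Async, so depending on the service's behavior the first response may not yet contain results; both methods return a bracketed failure marker rather than raising):

import asyncio

from app.services.audio_transcription_service import AudioTranscriptionService


async def main():
    url = "https://example.com/sample.mp3"
    text = await AudioTranscriptionService.transcribe_dashscope(url, api_key="sk-placeholder")
    if text.startswith("["):  # fall back to Whisper on failure
        text = await AudioTranscriptionService.transcribe_openai(url, api_key="sk-placeholder")
    print(text)

asyncio.run(main())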
@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
            "provider": api_key_config.provider,
            "api_key": api_key_config.api_key,
            "api_base": api_key_config.api_base,
            "is_omni": api_key_config.is_omni,
            "model_parameters": config_data.get("model_parameters", {}),
            "api_key_id": api_key_config.id
        }
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
            provider=agent_config["provider"],
            api_key=agent_config["api_key"],
            base_url=agent_config.get("api_base"),
            is_omni=agent_config.get("is_omni", False),
            extra_params=extra_params
        )

@@ -17,15 +17,18 @@ from sqlalchemy.orm import Session

from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
@@ -52,8 +55,12 @@ class LongTermMemoryInput(BaseModel):
description="The optimized, rewritten query. Rewrite the user's original question into a form better suited for retrieval, including keywords, context, and concrete details; check for and correct misspelled words while rewriting")


def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None):
def create_long_term_memory_tool(
memory_config: Dict[str, Any],
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
):
"""Create the long-term memory tool.

@@ -61,6 +68,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
memory_config: memory configuration
end_user_id: user ID
storage_type: storage type (optional)
user_rag_memory_id: user RAG memory ID (optional)

Returns:
the long-term memory tool
@@ -96,9 +104,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f"Long-term memory tool invoked! question={question}, user={end_user_id}")
try:
from app.db import get_db
db = next(get_db())
try:
with get_db_context() as db:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
end_user_id=end_user_id,
@@ -120,9 +126,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
logger.info(f"Read task status: {status}")
if memory_content:
memory_content = memory_content['answer']

finally:
db.close()
logger.info(f'User ID: Agent: {end_user_id}')
logger.debug("Calling the long-term memory API", extra={"question": question, "end_user_id": end_user_id})

@@ -188,7 +191,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
"""Retrieve relevant information from the knowledge base. Use this tool when the user's question needs to consult a knowledge base, documents, or history records.

Args:
query: the question or keywords to search for
kb_config: knowledge base configuration
kb_ids: list of knowledge base IDs
user_id: user ID

Returns:
the retrieved knowledge content
@@ -232,17 +237,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
return knowledge_retrieval_tool


class DraftRunService:
"""Draft-run service class"""
class AgentRunService:
"""Agent run service class"""

def __init__(self, db: Session):
"""Initialize the draft-run service
"""Agent run service

Args:
db: database session
"""
self.db = db

@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: list
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:
if variable.get("required") and variable.get("name") not in input_vars:
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
return input_vars
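A quick sketch of prepare_variables in action, with a hypothetical variables_config; the {"name": ..., "required": ...} entry shape matches the lookups above:

config = [
    {"name": "topic", "required": True},
    {"name": "tone", "required": False},
]

# Passes: the required "topic" is provided.
AgentRunService.prepare_variables({"topic": "pricing"}, config)

# Raises ValueError: "The required parameter 'topic' was not provided".
AgentRunService.prepare_variables({}, config)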
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""Load the tool configuration"""
if not tools_config:
return []
tools = []
tool_service = ToolService(self.db)

if tools_config and isinstance(tools_config, list):
for tool_config in tools_config:
if tool_config.get("enabled", False):
# Look up the tool instance by tool ID
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# Convert to a LangChain tool
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)

logger.debug(
"Web search tool added",
extra={
"tool_count": len(tools)
}
)
return tools

def load_skill_config(
self,
skills_config: dict | None,
message: str, tenant_id
) -> tuple[list, str]:
if not skills_config:
return [], ""

tools = []
skill_prompts = ""
skill_enable = skills_config.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills_config)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"Loaded {len(skill_tools)} skill tools")

if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"{len(tools)} tools remain after filtering")
skill_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)

return tools, skill_prompts

def load_knowledge_retrieval_config(
self,
knowledge_retrieval_config: dict | None,
user_id
) -> list:
if not knowledge_retrieval_config:
return []

tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# Create the knowledge retrieval tool
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
tools.append(kb_tool)

logger.debug(
"Knowledge retrieval tool added",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
return tools

def load_memory_config(
self,
memory_config: dict | None,
user_id,
storage_type,
user_rag_memory_id
) -> tuple[list, bool]:
"""Load the long-term memory configuration"""
if not memory_config:
return [], False

tools = []
if memory_config.get("enabled"):
if user_id:
# Create the long-term memory tool
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)

logger.debug(
"Long-term memory tool added",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
return tools, bool(memory_config.get("enabled"))
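These four loaders factor the tool-assembly blocks that run() and draft_run_stream() previously inlined into one reusable pattern; a condensed sketch of how they compose at a call site (it mirrors the replacement code below):

tools: list = []
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))

skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
    system_prompt = f"{system_prompt}\n\n{skill_prompts}"

tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))

memory_tools, memory_flag = self.load_memory_config(
    memory_config, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)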
async def run(
self,
*,
@@ -270,19 +399,21 @@ class DraftRunService:
conversation_id: conversation ID (for multi-turn dialogue)
user_id: user ID
variables: custom variable values
storage_type: storage type (optional)
user_rag_memory_id: user RAG memory ID (optional)
web_search: whether web search is enabled (default True)
memory: whether long-term memory is enabled (default True)
sub_agent: whether this is a sub-agent call (default False)
files: list of multimodal files (optional)

Returns:
Dict: a dict containing the AI reply and its metadata
"""
memory_flag = False

print('===========', storage_type)

print(user_id)
if variables is None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent

start_time = time.time()
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory

try:
# 1. Fetch the API key configuration
@@ -302,112 +433,40 @@ class DraftRunService:
agent_config=agent_config
)

items_params = variables
if sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}

system_prompt = render_prompt_message(
agent_config.system_prompt,  # fixed typo
agent_config.system_prompt,
PromptMessageRole.USER,
items_params
variables
)

# 3. Process the system prompt (supports variable substitution)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
print('System prompt:', system_prompt)

# 4. Prepare the tool list
tools = []

tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))

# Collect the enabled tools from the configuration
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
print("+" * 50)
print(f"agent_config:{agent_config}")
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# Look up the tool instance by tool ID
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# Convert to a LangChain tool
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)

logger.debug(
"Web search tool added",
extra={
"tool_count": len(tools)
}
)

# Load the tools associated with skills
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"Loaded {len(skill_tools)} skill tools")

# Apply dynamic filtering
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"{len(tools)} tools remain after filtering")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"

# Add the knowledge retrieval tool
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# Create the knowledge retrieval tool
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)

logger.debug(
"Knowledge retrieval tool added",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)

tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
# Add the long-term memory tool
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True

memory_config = agent_config.memory
if user_id:
# Create the long-term memory tool
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)

logger.debug(
"Long-term memory tool added",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(
memory_config, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)

# 4. Create the LangChain Agent
agent = LangChainAgent(
@@ -415,6 +474,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -431,7 +491,7 @@ class DraftRunService:

# 6. Load the conversation history
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
@@ -442,7 +502,7 @@ class DraftRunService:
if files:
# Get the provider information
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"Processed {len(processed_files)} files, provider={provider}")

@@ -481,7 +541,7 @@ class DraftRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))

# 9. Save the conversation messages
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -556,16 +616,21 @@ class DraftRunService:
Yields:
str: SSE-formatted event data
"""
memory_flag = False
if variables is None: variables = {}

from app.core.agent.langchain_agent import LangChainAgent
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory

start_time = time.time()

try:
# 1. Fetch the API key configuration
api_key_config = await self._get_api_key(model_config.id)
if not sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}

# 2. Merge the model parameters
effective_params = ModelParameterMerger.get_effective_parameters(
@@ -587,95 +652,22 @@ class DraftRunService:
# 4. Prepare the tool list
tools = []

tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))

# Collect the enabled tools from the configuration
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
for tool_config in agent_config.tools:
# print("+"*50)
# print(f"agent_config:{agent_config}")
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# Look up the tool instance by tool ID
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# Convert to a LangChain tool
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))

logger.debug(
"Web search tool added",
extra={
"tool_count": len(tools)
}
)

# Load the tools associated with skills
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"Loaded {len(skill_tools)} skill tools")

# Apply dynamic filtering
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"{len(tools)} tools remain after filtering")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"

# Add the knowledge retrieval tool
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# Create the knowledge retrieval tool
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)

logger.debug(
"Knowledge retrieval tool added",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
# Add the long-term memory tool
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True
memory_config = agent_config.memory
if user_id:
# Create the long-term memory tool
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)

logger.debug(
"Long-term memory tool added",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.extend(memory_tools)

# 4. Create the LangChain Agent
agent = LangChainAgent(
@@ -683,6 +675,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -700,10 +693,10 @@ class DraftRunService:

# 6. Load the conversation history
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
max_history=memory_config.get("max_history", 10)
)

# 6. Process multimodal files
@@ -711,7 +704,7 @@ class DraftRunService:
if files:
# Get the provider information
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"Processed {len(processed_files)} files, provider={provider}")

@@ -761,7 +754,7 @@ class DraftRunService:
})

# 10. Save the conversation messages
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -809,7 +802,7 @@ class DraftRunService:
"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
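_format_sse_event produces standard Server-Sent Events framing, with ensure_ascii=False keeping non-ASCII payloads readable on the wire; a hypothetical call and the string it yields:

# From inside the service:
frame = self._format_sse_event("message", {"conversation_id": "abc", "content": "你好"})
# event: message
# data: {"conversation_id": "abc", "content": "你好"}
# (a blank line terminates the event)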
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
"""Get the model's API key

Args:
@@ -846,7 +839,8 @@ class DraftRunService:
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
}
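With the hunk above, the dict returned by _get_api_key also carries is_omni, which the MultimodalService and LangChainAgent call sites read via api_key_config.get("is_omni", False); an illustrative return value (all field values hypothetical):

{
    "provider": "openai",
    "api_key": "sk-...",
    "api_base": "https://api.example.com/v1",
    "api_key_id": "b9a4c1d2-...",
    "is_omni": False,
}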
async def _ensure_conversation(
@@ -966,7 +960,6 @@ class DraftRunService:
List[Dict]: the history message list
"""
try:
from app.services.conversation_service import ConversationService

conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history(
@@ -1486,6 +1479,15 @@ class DraftRunService:
"conversation_id": returned_conversation_id,
"content": chunk
}))

if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
"model_config_id": model_config_id,
"label": model_label,
"conversation_id": returned_conversation_id,
"error": event_data.get("error", "未知错误")
}))
except Exception as e:
logger.warning(f"Failed to parse streaming event: {e}")
finally:
@@ -1670,41 +1672,3 @@ class DraftRunService:
"total_time": sum(r.get("elapsed_time", 0) for r in results)
}
)


async def draft_run(
db: Session,
*,
agent_config: AgentConfig,
model_config: ModelConfig,
message: str,
user_id: Optional[str] = None,
kb_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.7,
top_k: int = 3
) -> Dict[str, Any]:
"""Draft-run an Agent (convenience function)

Args:
db: database session
agent_config: Agent configuration
model_config: model configuration
message: user message
user_id: user ID
kb_ids: list of knowledge base IDs
similarity_threshold: similarity threshold
top_k: number of documents to return from retrieval

Returns:
Dict: a dict containing the AI reply and its metadata
"""
service = DraftRunService(db)
return await service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
user_id=user_id,
kb_ids=kb_ids,
similarity_threshold=similarity_threshold,
top_k=top_k
)

@@ -843,32 +843,33 @@ class EmotionAnalyticsService:
end_user_id: str,
db: Session,
) -> Optional[Dict[str, Any]]:
"""Get personalized emotion suggestions from the Redis cache
"""Get personalized emotion suggestions from the database

Args:
end_user_id: host ID (user group ID)
db: database session (parameter kept for interface compatibility)
db: database session

Returns:
Dict: the cached suggestion data, or None if missing or expired
Dict: the stored suggestion data, or None if missing
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository

logger.info(f"Trying to get emotion suggestions from the Redis cache: user={end_user_id}")
logger.info(f"Trying to get emotion suggestions from the database: user={end_user_id}")

# Fetch from the Redis cache
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
# Fetch the storage record from the database
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)

if cached_data is None:
logger.info(f"Suggestion cache for user {end_user_id} does not exist or has expired")
if storage is None or storage.emotion_suggestions is None:
logger.info(f"No suggestion data found for user {end_user_id}")
return None

logger.info(f"Successfully got suggestions from the Redis cache: user={end_user_id}")
return cached_data
logger.info(f"Successfully got suggestions from the database: user={end_user_id}")
return storage.emotion_suggestions

except Exception as e:
logger.error(f"Failed to get suggestions from the Redis cache: {str(e)}", exc_info=True)
logger.error(f"Failed to get suggestions from the database: {str(e)}", exc_info=True)
return None

async def save_suggestions_cache(
@@ -876,36 +877,27 @@ class EmotionAnalyticsService:
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
expires_hours: int = 24  # parameter kept for interface compatibility
) -> None:
"""Save suggestions to the Redis cache
"""Save suggestions to the database

Args:
end_user_id: host ID (user group ID)
suggestions_data: suggestion data
db: database session (parameter kept for interface compatibility)
expires_hours: expiry time in hours, default 24
db: database session
expires_hours: parameter kept (compatibility)
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository

logger.info(f"Saving suggestions to the Redis cache: user={end_user_id}, expires={expires_hours} hours")
logger.info(f"Saving suggestions to the database: user={end_user_id}")

# Compute the expiry time in seconds
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_emotion_suggestions(end_user_id, suggestions_data)
db.commit()

# Save to Redis
success = await EmotionMemoryCache.set_emotion_suggestions(
user_id=end_user_id,
suggestions_data=suggestions_data,
expire=expire_seconds
)

if success:
logger.info(f"Suggestion cache saved successfully: user={end_user_id}")
else:
logger.warning(f"Failed to save the suggestion cache: user={end_user_id}")
logger.info(f"Suggestions saved successfully: user={end_user_id}")

except Exception as e:
logger.error(f"Failed to save the suggestion cache: {str(e)}", exc_info=True)
# Do not raise; a cache failure must not affect the main flow
db.rollback()
logger.error(f"Failed to save suggestions: {str(e)}", exc_info=True)
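A sketch of the new storage round trip through the repository introduced above; the session handling is simplified and the payload illustrative:

from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository


def save_and_reload(db, end_user_id: str) -> None:
    repo = ImplicitEmotionsStorageRepository(db)

    # Write: persist the suggestions JSON, then commit.
    repo.update_emotion_suggestions(end_user_id, {"health_summary": "...", "suggestions": []})
    db.commit()

    # Read: fetch the row and use the stored JSON column.
    storage = repo.get_by_end_user_id(end_user_id)
    if storage is not None and storage.emotion_suggestions is not None:
        print(storage.emotion_suggestions)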
@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
is_omni=model_api_key.is_omni,
extra_params={
"temperature": 0.7,
"max_tokens": 2000,

@@ -422,32 +422,33 @@ class ImplicitMemoryService:
end_user_id: str,
db: Session
) -> Optional[dict]:
"""Get the full user profile from the Redis cache
"""Get the full user profile from the database

Args:
end_user_id: end-user ID
db: database session (parameter kept for interface compatibility)
db: database session

Returns:
Dict: the cached profile data, or None if missing or expired
Dict: the stored profile data, or None if missing
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository

logger.info(f"Trying to get the user profile from the Redis cache: user={end_user_id}")
logger.info(f"Trying to get the user profile from the database: user={end_user_id}")

# Fetch from the Redis cache
cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id)
# Fetch the storage record from the database
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)

if cached_data is None:
logger.info(f"Profile cache for user {end_user_id} does not exist or has expired")
if storage is None or storage.implicit_profile is None:
logger.info(f"No profile data found for user {end_user_id}")
return None

logger.info(f"Successfully got the user profile from the Redis cache: user={end_user_id}")
return cached_data
logger.info(f"Successfully got the user profile from the database: user={end_user_id}")
return storage.implicit_profile

except Exception as e:
logger.error(f"Failed to get the user profile from the Redis cache: {str(e)}", exc_info=True)
logger.error(f"Failed to get the user profile from the database: {str(e)}", exc_info=True)
return None

async def save_profile_cache(
@@ -455,36 +456,27 @@ class ImplicitMemoryService:
end_user_id: str,
profile_data: dict,
db: Session,
expires_hours: int = 168  # default 7 days
expires_hours: int = 168  # parameter kept for interface compatibility
) -> None:
"""Save the user profile to the Redis cache
"""Save the user profile to the database

Args:
end_user_id: end-user ID
profile_data: profile data
db: database session (parameter kept for interface compatibility)
expires_hours: expiry time in hours, default 168 (7 days)
db: database session
expires_hours: parameter kept (compatibility)
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository

logger.info(f"Saving the user profile to the Redis cache: user={end_user_id}, expires={expires_hours} hours")
logger.info(f"Saving the user profile to the database: user={end_user_id}")

# Compute the expiry time in seconds
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_implicit_profile(end_user_id, profile_data)
db.commit()

# Save to Redis
success = await ImplicitMemoryCache.set_user_profile(
user_id=end_user_id,
profile_data=profile_data,
expire=expire_seconds
)

if success:
logger.info(f"User profile cache saved successfully: user={end_user_id}")
else:
logger.warning(f"Failed to save the user profile cache: user={end_user_id}")
logger.info(f"User profile saved successfully: user={end_user_id}")

except Exception as e:
logger.error(f"Failed to save the user profile cache: {str(e)}", exc_info=True)
# Do not raise; a cache failure must not affect the main flow
db.rollback()
logger.error(f"Failed to save the user profile: {str(e)}", exc_info=True)
@@ -9,6 +9,8 @@ load_dotenv()

# Read the web_search environment variable
web_search_value = os.getenv('web_search')


def Search(query):
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
api_key = web_search_value
@@ -18,23 +20,24 @@ def Search(query):
"role": "user",
"content": query
}
], #search input
"edition":"standard", #search edition; defaults to standard; options: standard (full edition), lite (lighter edition with reduced recall scale and rerank count, better latency, slightly weaker than the full edition)
"search_source": "baidu_search_v2", #search engine version to use
"resource_type_filter": [{"type": "web","top_k": 20}], #supports web, video, image, and Aladdin search modalities; top_k maxes out at 50 for web, 10 for video, 30 for images, and 5 for Aladdin
], # search input
"edition": "standard", # search edition; defaults to standard; options: standard (full edition), lite (lighter edition with reduced recall scale and rerank count, better latency, slightly weaker than the full edition)
"search_source": "baidu_search_v2", # search engine version to use
"resource_type_filter": [{"type": "web", "top_k": 20}],
# supports web, video, image, and Aladdin search modalities; top_k maxes out at 50 for web, 10 for video, 30 for images, and 5 for Aladdin
"search_filter": {
"range": {
"page_time": {
"gte": "now-1w/d", #time filter, greater than or equal
"lt": "now/d", #time filter, less than
"gt": "", #time filter, greater than
"lte": "" #time filter, less than or equal
"gte": "now-1w/d", # time filter, greater than or equal
"lt": "now/d", # time filter, less than
"gt": "", # time filter, greater than
"lte": "" # time filter, less than or equal
}
}
},
"block_websites":["tieba.baidu.com"], #list of sites to block
"search_recency_filter":"week", #filter by page publish time; allowed values: week, month, semiyear, year
"enable_full_content":True #whether to return the full page content
"block_websites": ["tieba.baidu.com"], # list of sites to block
"search_recency_filter": "week", # filter by page publish time; allowed values: week, month, semiyear, year
"enable_full_content": True # whether to return the full page content
}, ensure_ascii=False)
headers = {
'Content-Type': 'application/json',
@@ -42,10 +45,10 @@ def Search(query):
}

response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
content=[]
content = []
for i in response['references']:
title=i['title']
snippet=i['snippet']
content.append(title+';'+snippet)
content='。'.join(content)
return content
title = i['title']
snippet = i['snippet']
content.append(title + ';' + snippet)
content = '。'.join(content)
return content
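A usage sketch for the cleaned-up Search helper; it assumes the web_search environment variable holds a valid Qianfan API key (read once at module import), and the query string is illustrative:

import os

os.environ["web_search"] = "bce-v3/..."  # hypothetical key; must be set before the module is imported

from app.services.langchain_tool_server import Search

results = Search("最新的大模型发布动态")  # illustrative query
print(results)  # "title;snippet。title;snippet。..." - entries joined by Chinese full stops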
@@ -414,6 +414,7 @@ class LLMRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.3,
max_tokens=500
)

@@ -392,6 +392,7 @@ class MasterAgentRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
extra_params=extra_params
)


@@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config
"""
import json
import os
import re
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import (
reorder_output_results,
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write  # new: import the write function directly
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
@@ -69,7 +66,8 @@ class MemoryAgentService:
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# Record the successful operation
if audit_logger:
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=True,
duration=duration, details={"message_length": len(message)})
return context
else:
@@ -88,8 +86,6 @@ class MemoryAgentService:

raise ValueError(f"Write failed: {messages}")


def extract_tool_call_info(self, event: Dict) -> bool:
"""Extract tool call information from event"""
last_message = event["messages"][-1]
@@ -271,7 +267,8 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic

async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int,
db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str:
"""
Process write operation with config_id

@@ -300,7 +297,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise  # Re-raise our specific error
@@ -331,7 +329,8 @@ class MemoryAgentService:
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)

raise ValueError(error_msg)

@@ -351,9 +350,9 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100*'-')
print(100 * '-')
print(langchain_messages)
print(100*'-')
print(100 * '-')
# Initial state - contains all required fields
initial_state = {
"messages": langchain_messages,
@@ -375,29 +374,28 @@ class MemoryAgentService:
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text,
contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)


async def read_memory(
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID]|int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
self,
end_user_id: str,
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[uuid.UUID] | int,
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
"""
Process read operation with config_id

@@ -425,7 +423,7 @@ class MemoryAgentService:

import time
start_time = time.time()
ori_message= message
ori_message = message

# Resolve config_id and workspace_id
# Always get workspace_id from end_user for fallback, even if config_id is provided
@@ -437,7 +435,8 @@ class MemoryAgentService:
config_id = connected_config.get("memory_config_id")
logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}")
if config_id is None and workspace_id is None:
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise  # Re-raise our specific error
@@ -454,7 +453,6 @@ class MemoryAgentService:
except ImportError:
audit_logger = None

config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -562,34 +560,35 @@ class MemoryAgentService:
from app.repositories.memory_short_repository import (
ShortTermMemoryRepository,
)

retrieved_content = []
repo = ShortTermMemoryRepository(db)

if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"Processing intermediate result: {intermediate}")
intermediate_type = intermediate.get('type', '')

if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
try:
reranked_results = raw_results.get('reranked_results', [])
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
statements = [statement['statement'] for statement in
reranked_results.get('statements', [])]
except Exception:
statements = []

# Deduplicate
statements = list(set(statements))

if query and statements:
retrieved_content.append({query: statements})

# If retrieved_content is empty, set it to an empty string
if retrieved_content == []:
retrieved_content = ''

# Save only when the answer is not "insufficient information" and this is not a quick retrieval
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# Use the upsert method
@@ -602,15 +601,17 @@ class MemoryAgentService:
)
logger.info(f"Short-term memory saved successfully: end_user_id={end_user_id}, search_switch={search_switch}")
else:
logger.debug(f"Skipping short-term memory save: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")

logger.debug(
f"Skipping short-term memory save: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")

except Exception as save_error:
# A save failure must not affect the main flow; just log the error
logger.error(f"Failed to save short-term memory: {str(save_error)}", exc_info=True)

# Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -641,7 +642,6 @@ class MemoryAgentService:
)
raise ValueError(error_msg)


def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
@@ -657,41 +657,43 @@ class MemoryAgentService:
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()

if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")

for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")

raise ValueError(
f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")

if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")

if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")

raise ValueError(
f"Message format error: Message must contain 'content' field. Error message index: {idx}")

if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")

if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")

logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
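The validation above accepts only dicts whose role is 'user' or 'assistant' and whose content is non-blank; a sketch of a passing and a failing input (the Write_UserInput constructor shape is an assumption):

from app.schemas.memory_agent_schema import Write_UserInput

service = MemoryAgentService()

# Passes: valid roles, non-empty content.
ok = Write_UserInput(messages=[
    {"role": "user", "content": "我喜欢爬山"},
    {"role": "assistant", "content": "好的,已经记住了"},
])
service.get_messages_list(ok)

# Raises ValueError: "Role must be 'user' or 'assistant', got: system. Message index: 0".
bad = Write_UserInput(messages=[{"role": "system", "content": "hi"}])
service.get_messages_list(bad)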
async def classify_message_type(
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
self,
message: str,
config_id: UUID,
db: Session,
workspace_id: Optional[UUID] = None
) -> Dict:
"""
Determine the type of user message (read or write)
@@ -719,14 +721,15 @@ class MemoryAgentService:
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status

async def generate_summary_from_retrieve(
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
self,
end_user_id: str,
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
db: Session
) -> str:
"""
Generate the final answer from the retrieved information, conversation history, and query
@@ -761,9 +764,9 @@ class MemoryAgentService:
if config_id is None:
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
# If config_id was provided, continue without workspace_id fallback

logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")

try:
# Load the configuration
config_service = MemoryConfigService(db)
@@ -772,7 +775,7 @@ class MemoryAgentService:
workspace_id=workspace_id,
service_name="MemoryAgentService"
)

# Import the required modules
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
summary_llm,
@@ -780,13 +783,13 @@ class MemoryAgentService:
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
)

# Build the state object
state = {
"data": query,
"memory_config": memory_config
}

# Call the summary_llm function directly
answer = await summary_llm(
state=state,
@@ -797,21 +800,20 @@ class MemoryAgentService:
response_model=RetrieveSummaryResponse,
search_mode="1"
)

logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...")
return answer if answer else "信息不足,无法回答。"

except Exception as e:
logger.error(f"Failed to generate the summary: {str(e)}", exc_info=True)
return "信息不足,无法回答。"

async def get_knowledge_type_stats(
self,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None,
db: Session = None
self,
db: Session,
end_user_id: Optional[str] = None,
only_active: bool = True,
current_workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""
Aggregate the knowledge-base type distribution, including:
@@ -837,11 +839,6 @@ class MemoryAgentService:

# 1. Count the knowledge-base types stored in PostgreSQL
try:
if db is None:
from app.db import get_db
db_gen = get_db()
db = next(db_gen)

# Initialize all standard types to 0
for kb_type in KnowledgeType:
result[kb_type.value] = 0
@@ -881,21 +878,19 @@ class MemoryAgentService:

# 3. Compute the knowledge-base type total (excluding memory)
result["total"] = (
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
result.get("General", 0) +
result.get("Web", 0) +
result.get("Third-party", 0) +
result.get("Folder", 0)
)

return result


async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
self,
end_user_id: Optional[str] = None,
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
Get the interest-distribution tags for the given user.
@@ -921,13 +916,12 @@ class MemoryAgentService:
logger.error(f"Interest-distribution tag query failed: {e}")
raise Exception(f"Interest-distribution tag query failed: {e}")

async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
Get the user details, including:
@@ -1017,7 +1011,8 @@ class MemoryAgentService:

# Define the structure for tag extraction
class UserTags(BaseModel):
tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")
tags: list[str] = Field(...,
description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友")

messages = [
{
@@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: when the end user does not exist or the app has not been published
"""
import json as json_module
import uuid

from sqlalchemy import select

@@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An

# 3. Legacy-data compatibility: if memory_config_id is empty, fetch it from AppRelease.config and backfill
memory_config_id_to_use = end_user.memory_config_id

# If a memory_config_id already exists, use it directly.
# For a newly created end user, end_user.memory_config_id is always None,
# so fetching the memory_config_id from the release is the expected behavior, backfilling it into
# end_user.memory_config_id
if not memory_config_id_to_use:
logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config")

# Get the latest release
stmt = (
select(AppRelease)
@@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
)
# TODO: change to current_release_id
latest_release = db.scalars(stmt).first()

if latest_release:
config = latest_release.config or {}

# If config is a string, parse it into a dict
if isinstance(config, str):
try:
@@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json_module.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {}

# Use MemoryConfigService's extraction method
memory_config_service = MemoryConfigService(db)
legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id(
app_type=app.type,
config=config
)

if legacy_config_id:
# Verify that the extracted config_id exists in the database
from app.models.memory_config_model import MemoryConfig as MemoryConfigModel
existing_config = db.get(MemoryConfigModel, legacy_config_id)

if existing_config:
memory_config_id_to_use = legacy_config_id

# Backfill into the end_user table (lazy update)
end_user.memory_config_id = memory_config_id_to_use
db.commit()
@@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"workspace_id": str(app.workspace_id)
}

logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
logger.info(
f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}")
return result
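End to end, the resolution order is: end_user.memory_config_id first, then a legacy id extracted from the latest AppRelease.config (lazily backfilled onto the end_user row), with callers falling back to the workspace default config when both are absent. An illustrative call showing the returned shape (values hypothetical):

connected = get_end_user_connected_config(end_user_id="eu-123", db=db)
# {"memory_config_id": "0f3c...", "workspace_id": "7d1a..."}
config_id = connected.get("memory_config_id")  # may be None for legacy int configs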
@@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
|
||||
# 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id
|
||||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||
|
||||
|
||||
# 创建映射 - 保留 EndUser 对象引用以便回填
|
||||
end_user_map = {str(eu.id): eu for eu in end_users}
|
||||
user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users}
|
||||
@@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
|
||||
# 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取
|
||||
users_needing_migration = [
|
||||
(end_user_id, data["app_id"])
|
||||
for end_user_id, data in user_data.items()
|
||||
(end_user_id, data["app_id"])
|
||||
for end_user_id, data in user_data.items()
|
||||
if not data["memory_config_id"]
|
||||
]
|
||||
|
||||
|
||||
if users_needing_migration:
|
||||
# 批量获取相关应用的最新发布版本
|
||||
migration_app_ids = list(set(app_id for _, app_id in users_needing_migration))
|
||||
|
||||
|
||||
# 查询每个应用的最新活跃发布版本
|
||||
app_latest_releases = {}
|
||||
for app_id in migration_app_ids:
|
||||
@@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
if latest_release:
|
||||
app_latest_releases[app_id] = latest_release
|
||||
|
||||
|
||||
# 为每个需要迁移的用户提取 memory_config_id
|
||||
config_service = MemoryConfigService(db)
|
||||
users_to_backfill = [] # [(end_user, memory_config_id), ...]
|
||||
|
||||
|
||||
for end_user_id, app_id in users_needing_migration:
|
||||
latest_release = app_latest_releases.get(app_id)
|
||||
if not latest_release:
|
||||
continue
|
||||
|
||||
|
||||
config = latest_release.config or {}
|
||||
|
||||
|
||||
# 如果 config 是字符串,解析为字典
|
||||
if isinstance(config, str):
|
||||
try:
|
||||
@@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
except json_module.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||
continue
|
||||
|
||||
|
||||
# 使用 MemoryConfigService 的提取方法
|
||||
app = app_map.get(app_id)
|
||||
if not app:
|
||||
continue
|
||||
|
||||
|
||||
legacy_config_id, is_legacy_int = config_service.extract_memory_config_id(
|
||||
app_type=app.type,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
if legacy_config_id:
|
||||
# 更新 user_data 中的 memory_config_id
|
||||
user_data[end_user_id]["memory_config_id"] = legacy_config_id
|
||||
|
||||
|
||||
# 记录需要回填的用户(稍后验证配置存在后再回填)
|
||||
end_user = end_user_map.get(end_user_id)
|
||||
if end_user:
|
||||
@@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
logger.info(
|
||||
f"Legacy int config detected for end_user {end_user_id}, will use workspace default"
|
||||
)
|
||||
|
||||
|
||||
# 验证提取的 config_id 是否存在于数据库中
|
||||
if users_to_backfill:
|
||||
config_ids_to_validate = list(set(cid for _, cid in users_to_backfill))
|
||||
@@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
MemoryConfig.config_id.in_(config_ids_to_validate)
|
||||
).all()
|
||||
valid_config_ids = {mc.config_id for mc in existing_configs}
|
||||
|
||||
|
||||
# 只回填存在的配置
|
||||
valid_backfills = [
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
if cid in valid_config_ids
|
||||
]
|
||||
invalid_backfills = [
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
(eu, cid) for eu, cid in users_to_backfill
|
||||
if cid not in valid_config_ids
|
||||
]
|
||||
|
||||
|
||||
if invalid_backfills:
|
||||
invalid_ids = [str(cid) for _, cid in invalid_backfills]
|
||||
logger.warning(
|
||||
@@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 清除 user_data 中无效的 config_id
|
||||
for eu, cid in invalid_backfills:
|
||||
user_data[str(eu.id)]["memory_config_id"] = None
|
||||
|
||||
|
||||
# 批量回填 end_user.memory_config_id
|
||||
if valid_backfills:
|
||||
for end_user, memory_config_id in valid_backfills:
|
||||
@@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id
|
||||
direct_config_ids = []
|
||||
workspace_fallback_users = [] # [(end_user_id, workspace_id), ...]
|
||||
|
||||
|
||||
for end_user_id, data in user_data.items():
|
||||
if data["memory_config_id"]:
|
||||
direct_config_ids.append(data["memory_config_id"])
|
||||
@@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑)
|
||||
workspace_default_configs = {}
|
||||
unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users))
|
||||
|
||||
|
||||
if unique_workspace_ids:
|
||||
config_service = MemoryConfigService(db)
|
||||
for workspace_id in unique_workspace_ids:
|
||||
@@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 7. 构建最终结果
|
||||
for end_user_id, data in user_data.items():
|
||||
memory_config = None
|
||||
|
||||
|
||||
# 优先使用 end_user 直接分配的配置
|
||||
if data["memory_config_id"]:
|
||||
memory_config = config_id_to_config.get(data["memory_config_id"])
|
||||
|
||||
|
||||
# 回退到工作空间默认配置
|
||||
if not memory_config:
|
||||
workspace_id = app_to_workspace.get(data["app_id"])
|
||||
@@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
|
||||
logger.info(f"Successfully retrieved {len(result)} connected configs")
|
||||
return result
|
||||
return result
|
||||
|
||||
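Taken together, the batch path resolves each user's memory config in a fixed order: the directly assigned ID, then a legacy ID extracted from the latest app release (validated and backfilled), then the workspace default. A compact sketch of that resolution order on plain values (names here are illustrative, not the service's API):

from typing import Optional

def resolve_memory_config_id(
    direct_id: Optional[str],
    legacy_id: Optional[str],
    legacy_id_is_valid: bool,
    workspace_default_id: Optional[str],
) -> Optional[str]:
    # 1. A config assigned directly on the end_user wins
    if direct_id:
        return direct_id
    # 2. A legacy ID extracted from AppRelease.config is only used once validated
    if legacy_id and legacy_id_is_valid:
        return legacy_id  # the real code also backfills this onto end_user
    # 3. Otherwise fall back to the workspace default (which may itself be None)
    return workspace_default_id

assert resolve_memory_config_id(None, "cfg-legacy", True, "cfg-default") == "cfg-legacy"
assert resolve_memory_config_id(None, "cfg-legacy", False, "cfg-default") == "cfg-default"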
@@ -140,9 +140,11 @@ class MemoryAPIService:

        try:
            # Delegate to MemoryAgentService
            # Convert string message to list[dict] format expected by MemoryAgentService
            messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
            result = await MemoryAgentService().write_memory(
                end_user_id=end_user_id,
                messages=message,
                messages=messages,
                config_id=config_id,
                db=self.db,
                storage_type=storage_type,

@@ -151,9 +153,18 @@ class MemoryAPIService:

            logger.info(f"Memory write successful for end_user: {end_user_id}")

            # result may be a string "success" or a dict with a "status" key
            # Preserve the full dict so callers don't silently lose extra fields
            # (e.g. error codes, metadata) returned by MemoryAgentService.
            if isinstance(result, dict):
                return {
                    **result,
                    "status": result.get("status", "unknown"),
                    "end_user_id": end_user_id,
                }
            return {
                "status": "success" if result == "success" else result,
                "end_user_id": end_user_id
                "status": result if isinstance(result, str) else "success",
                "end_user_id": end_user_id,
            }

        except ConfigurationError as e:
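The normalization above is worth pinning down: write_memory may hand back either a bare string or a dict, and the caller-facing shape should stay stable either way. A minimal sketch of the same normalization as a standalone helper (names are illustrative):

from typing import Any, Dict, Union

def normalize_write_result(result: Union[str, Dict[str, Any]], end_user_id: str) -> Dict[str, Any]:
    """Flatten a string-or-dict service result into one stable response shape."""
    if isinstance(result, dict):
        # Keep every extra field the service returned (error codes, metadata, ...)
        return {**result, "status": result.get("status", "unknown"), "end_user_id": end_user_id}
    # A bare string (e.g. "success") becomes the status verbatim
    return {"status": result if isinstance(result, str) else "success", "end_user_id": end_user_id}

assert normalize_write_result("success", "u1") == {"status": "success", "end_user_id": "u1"}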
@@ -390,19 +390,59 @@ def get_rag_total_kb(
    current_user: User
) -> int:
    """
    Count all distinct ids in the konwledges table for the current user's workspace_id
    Count the rows in the konwledges table for the current user's workspace_id, excluding user knowledge bases (permission_id != 'Memory')
    """
    workspace_id = current_user.current_workspace_id
    business_logger.info(f"Fetching total RAG KB count: workspace_id={workspace_id}, operator: {current_user.username}")
    business_logger.info(f"Fetching total RAG KB count (excluding user KBs): workspace_id={workspace_id}, operator: {current_user.username}")

    try:
        total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
        total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
        business_logger.info(f"Fetched total RAG KB count: {total_kb}")
        return total_kb
    except Exception as e:
        business_logger.error(f"Failed to fetch total RAG KB count: workspace_id={workspace_id} - {str(e)}")
        raise


def get_rag_user_kb_total_chunk(
    db: Session,
    current_user: User
) -> int:
    """
    For the current user's workspace_id, sum the chunk counts of all user knowledge bases from the documents table.
    Uses the same source as the /end_users endpoint: sums chunk_num over documents whose file_name matches end_user_id.txt.
    """
    workspace_id = current_user.current_workspace_id
    business_logger.info(f"Fetching total user-KB chunk count (documents table): workspace_id={workspace_id}, operator: {current_user.username}")

    try:
        from app.models.document_model import Document
        from app.models.end_user_model import EndUser
        from app.models.app_model import App
        from sqlalchemy import func

        # Collect every end_user_id in this workspace via the App relation
        end_user_ids = [
            str(eid) for (eid,) in db.query(EndUser.id)
            .join(App, EndUser.app_id == App.id)
            .filter(App.workspace_id == workspace_id)
            .all()
        ]
        if not end_user_ids:
            return 0

        file_names = [f"{uid}.txt" for uid in end_user_ids]
        result = db.query(func.sum(Document.chunk_num)).filter(
            Document.file_name.in_(file_names)
        ).scalar()

        total_chunk = int(result or 0)
        business_logger.info(f"Fetched total user-KB chunk count: {total_chunk}")
        return total_chunk
    except Exception as e:
        business_logger.error(f"Failed to fetch total user-KB chunk count: workspace_id={workspace_id} - {str(e)}")
        raise

def get_current_user_total_chunk(
    end_user_id: str,
    db: Session,
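The aggregate above is a plain SUM over an IN() filter, with "or 0" guarding the NULL that SUM returns when nothing matches. A self-contained sketch of the same pattern against an in-memory stand-in table (not the project's real model or session):

from sqlalchemy import create_engine, Column, Integer, String, func
from sqlalchemy.orm import declarative_base, Session

Base = declarative_base()

class Document(Base):  # stand-in for the project's Document model
    __tablename__ = "documents"
    id = Column(Integer, primary_key=True)
    file_name = Column(String)
    chunk_num = Column(Integer)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add_all([Document(file_name="u1.txt", chunk_num=3),
                Document(file_name="u2.txt", chunk_num=5),
                Document(file_name="other.txt", chunk_num=99)])
    db.commit()
    # Same shape as the service code: SUM(chunk_num) over an IN() list
    total = db.query(func.sum(Document.chunk_num)).filter(
        Document.file_name.in_(["u1.txt", "u2.txt"])
    ).scalar()
    assert int(total or 0) == 8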
@@ -1,45 +1,42 @@
# Changes to the memory_konwledges_server.py file

import asyncio
import os
import re
import uuid
from pathlib import Path
from typing import Optional

from pydantic import BaseModel, Field
from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success
from app.db import get_db
from app.schemas import file_schema, document_schema
from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query
from app.db import get_db_context
from app.models.document_model import Document
import uuid
from sqlalchemy.orm import Session
from fastapi import HTTPException, status

from app.core.config import settings
from app.models.user_model import User
from app.schemas import file_schema, document_schema
from app.schemas.file_schema import CustomTextFileCreate
from app.services import document_service, file_service, knowledge_service
from app.celery_app import celery_app
from app.core.logging_config import get_api_logger
from app.schemas.file_schema import CustomTextFileCreate
from app.db import get_db

# A simple user class for testing
api_logger = get_api_logger()


class ChunkCreate(BaseModel):
    content: str


class SimpleUser:
    def __init__(self, user_id: str):
        # Make sure the ID is a UUID
        self.id = user_id
        self.username = user_id

'''Parse'''

async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
    """
    Parse the specified document
@@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User):
        api_logger.error(f"Document parsing failed: document_id={document_id} - {str(e)}")
        raise

'''Get chunk IDs'''

async def get_document_chunks(
    kb_id: uuid.UUID,
    document_id: uuid.UUID,
@@ -198,7 +195,7 @@ async def get_document_chunks(

    return success(data=result, msg="Document chunk list fetched")

'''Find document ID'''

def find_document_id_by_kb_and_filename(
    db: Session,
    kb_id: str,
@@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename(
    except Exception as e:
        return None

'''Get knowledge base ID'''

def find_documents_by_kb_id(
    db: Session,
    kb_id: str,
@@ -268,18 +265,14 @@ def find_documents_by_kb_id(
    except Exception as e:
        return []

'''Upload file'''

async def memory_konwledges_up(
    kb_id: str,
    parent_id: str,
    create_data: file_schema.CustomTextFileCreate,
    db: Session = Depends(get_db),
    current_user: SimpleUser = None,  # Changed to SimpleUser
    db: Session,
    current_user: SimpleUser,
):
    # If no current_user was supplied, create a default one
    if current_user is None:
        current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60")

    content_bytes = create_data.content.encode('utf-8')
    file_size = len(content_bytes)
    print(f"file size: {file_size} bytes")
@@ -350,8 +343,6 @@ async def memory_konwledges_up(

    return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful")

'''Add a new chunk'''


async def create_document_chunk(
    kb_id: uuid.UUID,
@@ -417,7 +408,7 @@ async def create_document_chunk(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to query document chunks: {error_msg}"
        )

    sort_id = sort_id + 1

    # 5. Create the document chunk
@@ -450,6 +441,7 @@ async def create_document_chunk(

    return success(data=chunk, msg="Document chunk created")


async def write_rag(end_user_id, message, user_rag_memory_id):
    """
    Write a message into the RAG knowledge base
@@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
            detail=f"Invalid knowledge base ID format: {user_rag_memory_id}"
        )

    db_gen = get_db()
    db = next(db_gen)

    try:
    with get_db_context() as db:
        create_data = CustomTextFileCreate(title=end_user_id, content=message)
        current_user = SimpleUser(user_rag_memory_id)
        # Check whether the document already exists
        document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
        print('======',document)
        print('======', document)
        api_logger.info(f"Document lookup result: document_id={document}")
        if document is not None:
            # The document already exists; just add a new chunk
@@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id):
        else:
            api_logger.error(f"Could not find the document ID after creating the document: end_user_id={end_user_id}")
        return result
    finally:
        # Make sure the database session is closed
        db.close()
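The switch from a hand-driven get_db() generator to get_db_context() removes the next()/close() bookkeeping entirely. A minimal sketch of the context-manager shape being relied on here, with a local SQLite session factory standing in for the project's (commit-on-success behavior is an assumption, not confirmed from the source):

from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

SessionLocal = sessionmaker(bind=create_engine("sqlite://"))  # stand-in factory

@contextmanager
def get_db_context():
    # One `with` block replaces the old next(get_db()) / try / finally: db.close()
    db = SessionLocal()
    try:
        yield db
        db.commit()  # assumed: commit on clean exit
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()

with get_db_context() as db:
    pass  # queries go here; commit/rollback/close are handled by the context manager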
@@ -115,6 +115,17 @@ class DataConfigService:  # Data-config service class (PostgreSQL)

    # --- Create ---
    def create(self, params: ConfigParamsCreate) -> Dict[str, Any]:  # Create config params (name and description only)
        # Business-layer check: does a config with the same name already exist in this workspace?
        if params.workspace_id and params.config_name:
            from app.models.memory_config_model import MemoryConfig
            existing = (
                self.db.query(MemoryConfig)
                .filter_by(workspace_id=params.workspace_id, config_name=params.config_name)
                .first()
            )
            if existing:
                raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}")

        # If workspace_id is set and the model fields are not all specified, fetch them automatically
        if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
            configs = self._get_workspace_configs(params.workspace_id)
@@ -211,6 +222,7 @@ class DataConfigService:  # Data-config service class (PostgreSQL)
        "apply_id": config.apply_id,
        "scene_id": str(config.scene_id) if config.scene_id else None,
        "scene_name": scene_name,  # New: scene name
        "is_system_default": config.is_default,  # Whether this is the system default config
        "llm_id": config.llm_id,
        "embedding_id": config.embedding_id,
        "rerank_id": config.rerank_id,
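Because the duplicate check signals failure with a sentinel-prefixed ValueError, callers have to parse the name back out of the message. A sketch of how an API layer might translate that into a user-facing error (the handler and status mapping are illustrative, not the project's actual routing):

def create_config_or_409(service, params):
    """Map the DUPLICATE_CONFIG_NAME sentinel onto an HTTP-style (body, status) tuple."""
    try:
        return service.create(params), 200
    except ValueError as e:
        msg = str(e)
        if msg.startswith("DUPLICATE_CONFIG_NAME:"):
            dup_name = msg.split(":", 1)[1]
            return {"error": f"config name '{dup_name}' already exists"}, 409
        raise  # unrelated ValueErrors propagate unchanged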
@@ -90,7 +90,8 @@ class ModelConfigService:
        api_key: str,
        api_base: Optional[str] = None,
        model_type: str = "llm",
        test_message: str = "Hello"
        test_message: str = "Hello",
        is_omni: bool = False
    ) -> Dict[str, Any]:
        """Validate whether a model configuration works

@@ -102,6 +103,7 @@ class ModelConfigService:
            api_base: base URL of the API
            model_type: model type (llm/chat/embedding/rerank)
            test_message: test message
            is_omni: whether this is an Omni model

        Returns:
            Dict: validation result

@@ -119,6 +121,7 @@ class ModelConfigService:
            provider=provider,
            api_key=api_key,
            base_url=api_base,
            is_omni=is_omni,
            temperature=0.7,
            max_tokens=100
        )

@@ -257,8 +260,9 @@ class ModelConfigService:
            provider=model_data.provider,
            api_key=api_key_data.api_key,
            api_base=api_key_data.api_base,
            model_type=model_data.type,  # Pass the model type through
            test_message="Hello"
            model_type=model_data.type,
            test_message="Hello",
            is_omni=model_data.is_omni
        )
        if not validation_result["valid"]:
            raise BusinessException(

@@ -279,6 +283,9 @@ class ModelConfigService:
        for api_key_data in api_key_datas:
            api_key_data.model_name = model_data.name
            api_key_data.provider = model_data.provider
            # Keep capability and is_omni in sync
            api_key_data.capability = model_data.capability
            api_key_data.is_omni = model_data.is_omni
            api_key_create_schema = ModelApiKeyCreate(
                model_config_ids=[model.id],
                **api_key_data.model_dump()

@@ -473,6 +480,9 @@ class ModelApiKeyService:
        model_config = ModelConfigRepository.get_by_id(db, model_config_id)
        if not model_config:
            continue

        data.is_omni = model_config.is_omni
        data.capability = model_config.capability

        # Get model_name from ModelBase
        model_name = model_config.model_base.name if model_config.model_base else model_config.name

@@ -497,6 +507,8 @@ class ModelApiKeyService:
            existing_key.config = data.config
            existing_key.priority = data.priority
            existing_key.model_name = model_name
            existing_key.capability = data.capability
            existing_key.is_omni = data.is_omni

            # Check whether this model config is already linked
            if model_config not in existing_key.model_configs:

@@ -513,7 +525,8 @@ class ModelApiKeyService:
            api_key=data.api_key,
            api_base=data.api_base,
            model_type=model_config.type,
            test_message="Hello"
            test_message="Hello",
            is_omni=data.is_omni
        )
        if not validation_result["valid"]:
            # Record the models that failed validation, but do not raise

@@ -528,6 +541,8 @@ class ModelApiKeyService:
            provider=data.provider,
            api_key=data.api_key,
            api_base=data.api_base,
            capability=data.capability,
            is_omni=data.is_omni,
            config=data.config,
            is_active=data.is_active,
            priority=data.priority

@@ -550,6 +565,10 @@ class ModelApiKeyService:
        model_config = ModelConfigRepository.get_by_id(db, model_config_id)
        if not model_config:
            raise BusinessException("Model config not found", BizCode.MODEL_NOT_FOUND)
        if api_key_data.is_omni is None:
            api_key_data.is_omni = model_config.is_omni
        if api_key_data.capability is None:
            api_key_data.capability = model_config.capability

        # Check whether the API key already exists (including soft-deleted ones); tenant_id must be considered
        existing_key = db.query(ModelApiKey).join(

@@ -572,6 +591,8 @@ class ModelApiKeyService:
            existing_key.config = api_key_data.config
            existing_key.priority = api_key_data.priority
            existing_key.model_name = api_key_data.model_name
            existing_key.capability = api_key_data.capability
            existing_key.is_omni = api_key_data.is_omni

            # Check whether this model config is already linked
            if model_config not in existing_key.model_configs:

@@ -589,7 +610,8 @@ class ModelApiKeyService:
            api_key=api_key_data.api_key,
            api_base=api_key_data.api_base,
            model_type=model_config.type,
            test_message="Hello"
            test_message="Hello",
            is_omni=api_key_data.is_omni
        )
        if not validation_result["valid"]:
            raise BusinessException(

@@ -620,7 +642,8 @@ class ModelApiKeyService:
            api_key=api_key_data.api_key or existing_api_key.api_key,
            api_base=api_key_data.api_base or existing_api_key.api_base,
            model_type=model_config.type,
            test_message="Hello"
            test_message="Hello",
            is_omni=model_config.is_omni
        )
        if not validation_result["valid"]:
            raise BusinessException(

@@ -755,6 +778,8 @@ class ModelBaseService:
            "type": model_base.type,
            "logo": model_base.logo,
            "description": model_base.description,
            "capability": model_base.capability,
            "is_omni": model_base.is_omni,
            "is_composite": False
        }
        model_config = ModelConfigRepository.create(db, model_config_data)
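The recurring change in this file is threading an is_omni flag (and capability) from the stored config down into every validation call, falling back to the parent model config when the caller left the field unset. A compact, runnable sketch of that fallback shape (the dataclasses are stand-ins, not the project's schemas):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfigStub:
    type: str
    is_omni: bool
    capability: Optional[str]

@dataclass
class ApiKeyDataStub:
    api_key: str
    is_omni: Optional[bool] = None
    capability: Optional[str] = None

def resolve_omni_fields(data: ApiKeyDataStub, model_config: ModelConfigStub) -> ApiKeyDataStub:
    # Only fill the fields the caller left unset, mirroring the service's behavior
    if data.is_omni is None:
        data.is_omni = model_config.is_omni
    if data.capability is None:
        data.capability = model_config.capability
    return data

key = resolve_omni_fields(ApiKeyDataStub(api_key="sk-..."), ModelConfigStub("llm", True, "vision"))
assert key.is_omni is True and key.capability == "vision"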
@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
            user_id: user ID
            variables: variable parameters
            use_llm_routing: whether to use LLM routing
            web_search: whether web search is enabled
            memory: whether the memory feature is enabled
            storage_type: storage type
            user_rag_memory_id: user RAG memory ID

        Yields:
            SSE-formatted event stream
        """
        import json

        start_time = time.time()

@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
        except Exception as e:
            logger.error(
                "Multi-agent task execution failed (streaming)",
                extra={"error": str(e), "mode": self._normalized_mode}
                extra={"error": str(e), "mode": self._normalized_mode},
                exc_info=True
            )
            # Emit an error event
            yield self._format_sse_event("error", {

@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
        Yields:
            SSE-formatted event stream
        """
        from app.services.draft_run_service import DraftRunService
        from app.services.draft_run_service import AgentRunService

        # Fetch the model config
        model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)

@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
        )

        # Run the agent in streaming mode
        draft_service = DraftRunService(self.db)
        draft_service = AgentRunService(self.db)
        async for event in draft_service.run_stream(
            agent_config=agent_config,
            model_config=model_config,

@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
        Returns:
            execution result
        """
        from app.services.draft_run_service import DraftRunService
        from app.services.draft_run_service import AgentRunService

        # Fetch the model config
        model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)

@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
        )

        # Run the agent
        draft_service = DraftRunService(self.db)
        draft_service = AgentRunService(self.db)
        result = await draft_service.run(
            agent_config=agent_config,
            model_config=model_config,

@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
        self.memory = config_data.get("memory")
        self.variables = config_data.get("variables", [])
        self.tools = config_data.get("tools", {})
        self.skills = config_data.get("skills", {})
        self.default_model_config_id = release.default_model_config_id

        return AgentConfigProxy(release, app, config_data)

@@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator:
        provider=api_key_config.provider,
        api_key=api_key_config.api_key,
        base_url=api_key_config.api_base,
        is_omni=api_key_config.is_omni,
        temperature=0.7,  # The consolidation task uses a medium temperature
        max_tokens=2000
    )

@@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator:
        provider=api_key_config.provider,
        api_key=api_key_config.api_key,
        base_url=api_key_config.api_base,
        is_omni=api_key_config.is_omni,
        temperature=0.7,
        max_tokens=2000,
        extra_params={"streaming": True}  # Enable streaming output

@@ -267,7 +267,7 @@ class MultiAgentService:

    # 2. Validate the model config (if one was provided)
    if data.default_model_config_id:
        model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
        model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
        if not model_api_key:
            raise ResourceNotFoundException("Model config", str(data.default_model_config_id))
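The orchestrator reports failures twice: a structured log record (with exc_info=True so the traceback is captured) and an SSE "error" event yielded to the client. A sketch of the framing _format_sse_event presumably produces; the exact wire layout here follows the SSE spec and is an assumption, not taken from the project's code:

import json

def format_sse_event(event: str, data: dict) -> str:
    # Standard Server-Sent Events framing: event name, JSON payload, blank-line terminator
    return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

print(format_sse_event("error", {"error": "model timeout", "mode": "route"}))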
@@ -9,47 +9,100 @@
- OpenAI: supports URL and base64 formats
"""
import uuid
from typing import List, Dict, Any, Optional, Protocol
import httpx
import base64
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from docx import Document
import io
import PyPDF2

from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.models.generic_file_model import GenericFile
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
from app.services.audio_transcription_service import AudioTranscriptionService

logger = get_business_logger()


class ImageFormatStrategy(Protocol):
    """Image-format strategy interface"""
class MultimodalFormatStrategy(ABC):
    """Base class for multimodal format strategies"""

    @abstractmethod
    async def format_image(self, url: str) -> Dict[str, Any]:
        """Format an image"""
        pass

    @abstractmethod
    async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
        """Format a document"""
        pass

    @abstractmethod
    async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
        """Format an audio file"""
        pass

    @abstractmethod
    async def format_video(self, url: str) -> Dict[str, Any]:
        """Format a video"""
        pass


class DashScopeFormatStrategy(MultimodalFormatStrategy):
    """Tongyi Qianwen (DashScope) strategy"""

    async def format_image(self, url: str) -> Dict[str, Any]:
        """Convert an image URL into the provider-specific format"""
        ...


class DashScopeImageStrategy:
    """Tongyi Qianwen image-format strategy"""

    async def format_image(self, url: str) -> Dict[str, Any]:
        """DashScope format: {"type": "image", "image": "url"}"""
        """DashScope image format: {"type": "image", "image": "url"}"""
        return {
            "type": "image",
            "image": url
        }

    async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
        """DashScope document format"""
        return {
            "type": "text",
            "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
        }

class BedrockImageStrategy:
    """Bedrock/Anthropic image-format strategy"""
    async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
        """
        DashScope audio format
        - Native support: the qwen-audio series
        - Other models: must be transcribed to text
        """
        if transcription:
            return {
                "type": "text",
                "text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
            }
        # DashScope audio format: {"type": "audio", "audio": "url"}
        return {
            "type": "audio",
            "audio": url
        }

    async def format_video(self, url: str) -> Dict[str, Any]:
        """DashScope video format (natively supported by the qwen-vl series)"""
        return {
            "type": "video",
            "video": url
        }
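The refactor replaces a Protocol that covered images only with an ABC covering all four modalities, so each provider ships one cohesive strategy object. A minimal runnable sketch of exercising such a strategy (abridged stand-in classes, not the real service):

import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict

class FormatStrategy(ABC):  # abridged stand-in for MultimodalFormatStrategy
    @abstractmethod
    async def format_image(self, url: str) -> Dict[str, Any]: ...

class DashScopeLike(FormatStrategy):
    async def format_image(self, url: str) -> Dict[str, Any]:
        # URL pass-through, as in the DashScope strategy above
        return {"type": "image", "image": url}

async def main() -> None:
    strategy: FormatStrategy = DashScopeLike()
    print(await strategy.format_image("https://example.com/cat.png"))

asyncio.run(main())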
class BedrockFormatStrategy(MultimodalFormatStrategy):
    """Bedrock/Anthropic strategy"""

    async def format_image(self, url: str) -> Dict[str, Any]:
        """
        Bedrock/Anthropic format: base64 encoded
        {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
        """
        import httpx
        import base64
        from mimetypes import guess_type

        logger.info(f"Downloading and encoding image: {url}")
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
            }
        }

    async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
        """Bedrock/Anthropic document format (requires base64 encoding)"""
        # Bedrock documents must be base64 encoded
        text_bytes = text.encode('utf-8')
        base64_text = base64.b64encode(text_bytes).decode('utf-8')

class OpenAIImageStrategy:
    """OpenAI image-format strategy"""
        return {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "text/plain",
                "data": base64_text
            }
        }

    async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
        """
        Bedrock/Anthropic audio format
        No native audio support; the audio must be transcribed to text
        """
        if transcription:
            return {
                "type": "text",
                "text": f"[Audio transcription]\n{transcription}"
            }
        return {
            "type": "text",
            "text": "[Audio file: Bedrock has no native audio support, please enable audio-to-text]"
        }

    async def format_video(self, url: str) -> Dict[str, Any]:
        """Bedrock/Anthropic video format"""
        return {
            "type": "text",
            "text": f"<video url=\"{url}\">\n[Video file, not supported by the current provider]\n</video>"
        }
class OpenAIFormatStrategy(MultimodalFormatStrategy):
    """OpenAI strategy"""

    async def format_image(self, url: str) -> Dict[str, Any]:
        """OpenAI format: {"type": "image_url", "image_url": {"url": "..."}}"""
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
            }
        }

    async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
        """OpenAI document format"""
        return {
            "type": "text",
            "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
        }

    async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
        """
        OpenAI audio format
        - The gpt-4o-audio series supports native audio (base64 encoded)
        - Other models use the transcribed text
        """
        if transcription:
            return {
                "type": "text",
                "text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
            }

        # OpenAI audio must be base64 encoded
        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(url)
                response.raise_for_status()
                audio_data = response.content
            base64_audio = base64.b64encode(audio_data).decode('utf-8')
            # 1. Prefer the extension derived from file_type (MIME)
            file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
            # 2. Fall back to the content-type response header
            if not file_ext:
                ct = response.headers.get("content-type", "")
                file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
            # 3. Fall back to the extension in the URL path
            if not file_ext:
                file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
            # 4. Default to wav
            # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
            file_ext = "wav" if not file_ext else file_ext

            return {
                "type": "input_audio",
                "input_audio": {
                    "data": f"data:;base64,{base64_audio}",
                    "format": file_ext
                }
            }
        except Exception as e:
            logger.error(f"Failed to download audio: {e}")
            return {
                "type": "text",
                "text": f"[Audio processing failed: {str(e)}]"
            }

    async def format_video(self, url: str) -> Dict[str, Any]:
        """OpenAI video format"""
        return {
            "type": "video_url",
            "video_url": {
                "url": url
            }
        }


# Provider-to-strategy mapping
PROVIDER_STRATEGIES = {
    "dashscope": DashScopeImageStrategy,
    "bedrock": BedrockImageStrategy,
    "anthropic": BedrockImageStrategy,
    "openai": OpenAIImageStrategy,
    "dashscope": DashScopeFormatStrategy,
    "bedrock": BedrockFormatStrategy,
    "anthropic": BedrockFormatStrategy,
    "openai": OpenAIFormatStrategy,
}
class MultimodalService:
    """Multimodal file-processing service"""

    def __init__(self, db: Session, provider: str = "dashscope"):
    def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
        """
        Initialize the multimodal service

        Args:
            db: database session
            provider: model provider (dashscope, bedrock, anthropic, etc.)
            provider: model provider (dashscope, bedrock, anthropic, openai, etc.)
            api_key: API key (used for audio-to-text)
            enable_audio_transcription: whether audio-to-text is enabled
            is_omni: whether this is an Omni model (DashScope omni models must use the OpenAI-compatible format)
        """
        self.db = db
        self.provider = provider.lower()
        self.api_key = api_key
        self.enable_audio_transcription = enable_audio_transcription
        self.is_omni = is_omni

    async def process_files(
        self,
@@ -137,20 +295,32 @@ class MultimodalService:
        if not files:
            return []

        # Pick the matching strategy
        # DashScope omni models use the OpenAI-compatible format
        if self.provider == "dashscope" and self.is_omni:
            strategy_class = OpenAIFormatStrategy
        else:
            strategy_class = PROVIDER_STRATEGIES.get(self.provider)
            if not strategy_class:
                logger.warning(f"No strategy found for provider '{self.provider}', falling back to the default")
                strategy_class = DashScopeFormatStrategy

        strategy = strategy_class()

        result = []
        for idx, file in enumerate(files):
            try:
                if file.type == FileType.IMAGE:
                    content = await self._process_image(file)
                    content = await self._process_image(file, strategy)
                    result.append(content)
                elif file.type == FileType.DOCUMENT:
                    content = await self._process_document(file)
                    content = await self._process_document(file, strategy)
                    result.append(content)
                elif file.type == FileType.AUDIO:
                    content = await self._process_audio(file)
                    content = await self._process_audio(file, strategy)
                    result.append(content)
                elif file.type == FileType.VIDEO:
                    content = await self._process_video(file)
                    content = await self._process_video(file, strategy)
                    result.append(content)
                else:
                    logger.warning(f"Unsupported file type: {file.type}")
@@ -172,55 +342,29 @@ class MultimodalService:
        logger.info(f"Processed {len(result)}/{len(files)} files successfully, provider={self.provider}")
        return result
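Resolving the strategy once per process_files call keeps the omni special case in a single place. A sketch of that resolution in isolation, on plain strings for illustration (the real classes live above):

def resolve_strategy_name(provider: str, is_omni: bool) -> str:
    # Mirror of the dispatch above
    mapping = {"dashscope": "DashScopeFormatStrategy",
               "bedrock": "BedrockFormatStrategy",
               "anthropic": "BedrockFormatStrategy",
               "openai": "OpenAIFormatStrategy"}
    if provider == "dashscope" and is_omni:
        return "OpenAIFormatStrategy"  # omni models speak the OpenAI-compatible format
    return mapping.get(provider, "DashScopeFormatStrategy")  # default fallback

assert resolve_strategy_name("dashscope", True) == "OpenAIFormatStrategy"
assert resolve_strategy_name("unknown", False) == "DashScopeFormatStrategy"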
    async def _process_image(self, file: FileInput) -> Dict[str, Any]:
    async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
        """
        Process an image file

        Args:
            file: image file input
            strategy: formatting strategy

        Returns:
            Dict: format differs per provider
            - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
            - DashScope: {"type": "image", "image": "url"}
            Dict: image content, format differs per provider
        """
        url = await self.get_file_url(file)

        logger.debug(f"Processing image: {url}, provider={self.provider}")

        # Return a provider-specific format
        if self.provider in ["bedrock", "anthropic"]:
            # Anthropic/Bedrock only accepts base64, so download and convert
            try:
                logger.info(f"Downloading and encoding image: {url}")
                base64_data, media_type = await self._download_and_encode_image(url)
                result = {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": media_type,
                        "data": base64_data[:100] + "..."  # log only the first 100 characters
                    }
                }
                logger.info(f"Image encoded: media_type={media_type}, data_length={len(base64_data)}")
                # Return the full data
                result["source"]["data"] = base64_data
                return result
            except Exception as e:
                logger.error(f"Failed to download and encode image: {e}", exc_info=True)
                # Return an error placeholder
                return {
                    "type": "text",
                    "text": f"[Image loading failed: {str(e)}]"
                }
        else:
            # DashScope and other formats accept a URL
            try:
                url = await self.get_file_url(file)
                return await strategy.format_image(url)
            except Exception as e:
                logger.error(f"Failed to process image: {e}", exc_info=True)
                return {
                    "type": "image",
                    "image": url
                    "type": "text",
                    "text": f"[Image processing failed: {str(e)}]"
                }

    async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
    @staticmethod
    async def _download_and_encode_image(url: str) -> tuple[str, str]:
        """
        Download an image and convert it to base64

@@ -230,8 +374,6 @@ class MultimodalService:
        Returns:
            tuple: (base64_data, media_type)
        """
        import httpx
        import base64
        from mimetypes import guess_type

        # Download the image
@@ -258,15 +400,16 @@ class MultimodalService:

        return base64_data, media_type
    async def _process_document(self, file: FileInput) -> Dict[str, Any]:
    async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
        """
        Process a document file (PDF, Word, etc.)

        Args:
            file: document file input
            strategy: formatting strategy

        Returns:
            Dict: text content (with the extracted text)
            Dict: document content, format differs per provider
        """
        if file.transfer_method == TransferMethod.REMOTE_URL:
            # Extraction is not yet supported for remote documents
@@ -277,48 +420,68 @@ class MultimodalService:
        else:
            # Local file: extract the text content
            text = await self._extract_document_text(file.upload_file_id)
            generic_file = self.db.query(GenericFile).filter(
                GenericFile.id == file.upload_file_id
            file_metadata = self.db.query(FileMetadata).filter(
                FileMetadata.id == file.upload_file_id
            ).first()

            file_name = generic_file.file_name if generic_file else "unknown"
            file_name = file_metadata.file_name if file_metadata else "unknown"

            return {
                "type": "text",
                "text": f"<document name=\"{file_name}\">\n{text}\n</document>"
            }
            # Format the document through the strategy
            return await strategy.format_document(file_name, text)

    async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
    async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
        """
        Process an audio file

        Args:
            file: audio file input
            strategy: formatting strategy

        Returns:
            Dict: audio content (placeholder for now)
            Dict: audio content, format differs per provider
        """
        # TODO: implement audio-to-text
        return {
            "type": "text",
            "text": "[Audio file, processing not supported yet]"
        }
        try:
            url = await self.get_file_url(file)

    async def _process_video(self, file: FileInput) -> Dict[str, Any]:
            # If audio transcription is enabled and an API key is available
            transcription = None
            if self.enable_audio_transcription and self.api_key:
                logger.info(f"Starting audio transcription: {url}")
                if self.provider == "dashscope":
                    transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
                elif self.provider == "openai":
                    transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
                else:
                    logger.warning(f"Provider {self.provider} does not support audio transcription")

            return await strategy.format_audio(file.file_type, url, transcription)
        except Exception as e:
            logger.error(f"Failed to process audio: {e}", exc_info=True)
            return {
                "type": "text",
                "text": f"[Audio processing failed: {str(e)}]"
            }

    async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
        """
        Process a video file

        Args:
            file: video file input
            strategy: formatting strategy

        Returns:
            Dict: video content (placeholder for now)
            Dict: video content, format differs per provider
        """
        # TODO: implement video processing
        return {
            "type": "text",
            "text": "[Video file, processing not supported yet]"
        }
        try:
            url = await self.get_file_url(file)
            return await strategy.format_video(url)
        except Exception as e:
            logger.error(f"Failed to process video: {e}", exc_info=True)
            return {
                "type": "text",
                "text": f"[Video processing failed: {str(e)}]"
            }
    async def get_file_url(self, file: FileInput) -> str:
        """
@@ -336,26 +499,22 @@ class MultimodalService:
        if file.transfer_method == TransferMethod.REMOTE_URL:
            return file.url
        else:
            # Local file: fetch a permanent access URL from the file_storage system
            from app.models.file_metadata_model import FileMetadata
            from app.core.config import settings

            file_id = file.upload_file_id
            print("="*50)
            print("file_id", file_id)

            # Query FileMetadata
            file_metadata = self.db.query(FileMetadata).filter(
                FileMetadata.id == file_id,
                FileMetadata.status == "completed"
            ).first()

            if not file_metadata:
                raise BusinessException(
                    f"File does not exist or was deleted: {file_id}",
                    BizCode.NOT_FOUND
                )

            # Return the permanent URL
            server_url = settings.FILE_LOCAL_SERVER_URL
            return f"{server_url}/storage/permanent/{file_id}"
@@ -370,58 +529,79 @@ class MultimodalService:
        Returns:
            str: the extracted text content
        """
        generic_file = self.db.query(GenericFile).filter(
            GenericFile.id == file_id,
            GenericFile.status == "active"
        file_metadata = self.db.query(FileMetadata).filter(
            FileMetadata.id == file_id,
            FileMetadata.status == "completed"
        ).first()

        if not generic_file:
        if not file_metadata:
            raise BusinessException(
                f"File does not exist or was deleted: {file_id}",
                BizCode.NOT_FOUND
            )

        # TODO: extract text depending on the file type
        # - PDF: PyPDF2 or pdfplumber
        # - Word: python-docx
        # - TXT/MD: read directly

        file_ext = generic_file.file_ext.lower()
        file_ext = file_metadata.file_ext.lower()
        server_url = settings.FILE_LOCAL_SERVER_URL
        file_url = f"{server_url}/storage/permanent/{file_id}"

        if file_ext in ['.txt', '.md', '.markdown']:
            return await self._read_text_file(generic_file.storage_path)
            return await self._read_text_file(file_url)
        elif file_ext == '.pdf':
            return await self._extract_pdf_text(generic_file.storage_path)
            return await self._extract_pdf_text(file_url)
        elif file_ext in ['.doc', '.docx']:
            return await self._extract_word_text(generic_file.storage_path)
            return await self._extract_word_text(file_url)
        else:
            return f"[Unsupported document format: {file_ext}]"

    async def _read_text_file(self, storage_path: str) -> str:
    @staticmethod
    async def _read_text_file(file_url: str) -> str:
        """Read a plain-text file"""
        try:
            with open(storage_path, 'r', encoding='utf-8') as f:
                return f.read()
            # Download the file
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(file_url)
                response.raise_for_status()
                return response.text
        except Exception as e:
            logger.error(f"Failed to read text file: {e}")
            return f"[File read failed: {str(e)}]"

    async def _extract_pdf_text(self, storage_path: str) -> str:
    @staticmethod
    async def _extract_pdf_text(file_url: str) -> str:
        """Extract text from a PDF"""
        try:
            # TODO: implement PDF text extraction
            # import PyPDF2 or pdfplumber
            return "[PDF text extraction not implemented yet]"
            # Download the PDF file
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(file_url)
                response.raise_for_status()
                pdf_data = response.content

            # Read the PDF through BytesIO
            text_parts = []
            pdf_file = io.BytesIO(pdf_data)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            for page in pdf_reader.pages:
                text_parts.append(page.extract_text())
            return '\n'.join(text_parts)
        except Exception as e:
            logger.error(f"Failed to extract PDF text: {e}")
            return f"[PDF extraction failed: {str(e)}]"

    async def _extract_word_text(self, storage_path: str) -> str:
    @staticmethod
    async def _extract_word_text(file_url: str) -> str:
        """Extract text from a Word document"""
        try:
            # TODO: implement Word text extraction
            # import docx
            return "[Word text extraction not implemented yet]"
            # Download the Word file
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(file_url)
                response.raise_for_status()
                word_data = response.content

            # Read the Word document through BytesIO
            word_file = io.BytesIO(word_data)
            doc = Document(word_file)
            text_parts = [paragraph.text for paragraph in doc.paragraphs]
            return '\n'.join(text_parts)
        except Exception as e:
            logger.error(f"Failed to extract Word text: {e}")
            return f"[Word extraction failed: {str(e)}]"
@@ -184,7 +184,8 @@ class PromptOptimizerService:
        model_name=api_config.model_name,
        provider=api_config.provider,
        api_key=api_config.api_key,
        base_url=api_config.api_base
        base_url=api_config.api_base,
        is_omni=api_config.is_omni
    ), type=ModelType(model_config.type))
    try:
        prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
@@ -247,6 +247,7 @@ class SharedChatService:
        api_key=api_key_obj.api_key,
        provider=api_key_obj.provider,
        api_base=api_key_obj.api_base,
        is_omni=api_key_obj.is_omni,
        temperature=model_parameters.get("temperature", 0.7),
        max_tokens=model_parameters.get("max_tokens", 2000),
        system_prompt=system_prompt,

@@ -454,6 +455,7 @@ class SharedChatService:
        api_key=api_key_obj.api_key,
        provider=api_key_obj.provider,
        api_base=api_key_obj.api_base,
        is_omni=api_key_obj.is_omni,
        temperature=model_parameters.get("temperature", 0.7),
        max_tokens=model_parameters.get("max_tokens", 2000),
        system_prompt=system_prompt,
@@ -121,7 +121,7 @@ class SkillService:
        if skill and skill.is_active:
            # Load the tools linked to the skill
            for tool_config in skill.tools:
                tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
                tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
                if tool:
                    langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
                    tools.append(langchain_tool)
@@ -8,6 +8,8 @@ from datetime import datetime

from sqlalchemy.orm import Session

from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import (
    ToolRepository, BuiltinToolRepository, CustomToolRepository,
@@ -79,6 +81,18 @@ class ToolService:
        config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
        return self._config_to_info(config) if config else None

    def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
        """Check whether a tool name is already taken"""
        query = self.db.query(ToolConfig).filter(
            ToolConfig.name == name,
            ToolConfig.tool_type == tool_type.value,
            ToolConfig.tenant_id == tenant_id
        )
        if exclude_id:
            query = query.filter(ToolConfig.id != exclude_id)
        if query.first():
            raise BusinessException(f"Tool name '{name}' already exists", BizCode.DUPLICATE_NAME)

    def create_tool(
        self,
        name: str,
@@ -92,6 +106,7 @@ class ToolService:
        """Create a tool"""
        if tool_type == ToolType.BUILTIN:
            raise ValueError("Built-in tools cannot be created")
        self._check_name_duplicate(name, tool_type, tenant_id)

        try:
            # Create the base configuration
@@ -141,6 +156,7 @@ class ToolService:
            raise ValueError("Built-in tools do not allow changing the name, description, or icon")
        try:
            if name:
                self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id)
                config_obj.name = name
            if description:
                config_obj.description = description
@@ -209,7 +225,7 @@ class ToolService:

        try:
            # Get the tool instance
            tool = self._get_tool_instance(tool_id, tenant_id)
            tool = self.get_tool_instance(tool_id, tenant_id)
            if not tool:
                return ToolResult.error_result(
                    error=f"Tool does not exist: {tool_id}",
@@ -335,7 +351,7 @@ class ToolService:
            return []

        # Get the tool instance
        tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
        tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
        if not tool_instance:
            return []

@@ -792,7 +808,7 @@ class ToolService:
        """Get a tool configuration"""
        return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)

    def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
    def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
        """Get a tool instance"""
        if tool_id in self._tool_cache:
            return self._tool_cache[tool_id]
@@ -1416,7 +1432,7 @@ class ToolService:
        """Test a built-in tool connection"""
        try:
            # Get the tool instance
            tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
            tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
            if not tool_instance:
                return {"success": False, "message": "Could not create a tool instance"}
@@ -21,8 +21,7 @@ from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.cypher_queries import Graph_Node_query
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.memory_base_service import MemoryBaseService, MemoryTransService
from app.services.memory_base_service import MemoryBaseService
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
@@ -1167,7 +1166,6 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: str

    from app.core.language_utils import validate_language
    from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
    from app.db import get_db
    from app.repositories.end_user_repository import EndUserRepository

    # Validate the language parameter
@@ -1178,8 +1176,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: str
    if end_user_id:
        try:
            # Get a database session and look up the user
            db = next(get_db())
            try:
            with get_db_context() as db:
                repo = EndUserRepository(db)
                end_user = repo.get_by_id(uuid.UUID(end_user_id))
                if end_user and end_user.other_name:
@@ -1187,8 +1184,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: str
                    logger.info(f"Using other_name as the user display name: {user_display_name}")
                else:
                    logger.info(f"other_name for user {end_user_id} is empty, using the default form of address: {user_display_name}")
            finally:
                db.close()

        except Exception as e:
            logger.warning(f"Failed to fetch the user's other_name, using the default form of address: {str(e)}")
@@ -16,6 +16,7 @@ from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -453,11 +454,14 @@ class WorkflowService:
        files_struct = []
        for file in files:
            files_struct.append(
                {
                    "type": file.type,
                    "url": await self.multimodal_service.get_file_url(file),
                    "__file": True
                }
                FileObject(
                    type=file.type,
                    url=await self.multimodal_service.get_file_url(file),
                    transfer_method=file.transfer_method,
                    file_id=str(file.upload_file_id),
                    origin_file_type=file.file_type,
                    is_file=True
                ).model_dump()
            )
        return files_struct
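Swapping the ad-hoc {"__file": True} dict for a typed FileObject keeps the file payload schema in one model. A sketch of what such a pydantic model and its model_dump() round-trip might look like (the fields mirror the call above; the real FileObject lives in app.core.workflow.variable.base_variable and may differ):

from typing import Optional
from pydantic import BaseModel

class FileObject(BaseModel):  # illustrative stand-in, fields as used above
    type: str
    url: str
    transfer_method: str
    file_id: str
    origin_file_type: Optional[str] = None
    is_file: bool = True

payload = FileObject(
    type="image",
    url="https://files.example.com/storage/permanent/abc",
    transfer_method="local_file",
    file_id="abc",
    origin_file_type="image/png",
).model_dump()
assert payload["is_file"] is True  # replaces the old "__file" marker key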
@@ -107,6 +107,17 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]:
    for workspace in workspaces:
        if workspace.storage_type == 'neo4j':
            _ensure_default_memory_config(db, workspace)
            _ensure_default_ontology_scenes(db, workspace)

    business_logger.info(f"Workspace count for user {user.username}: {len(workspaces)}")
    return workspaces
@@ -1104,6 +1105,52 @@ def _fill_workspace_configs_model_defaults(
    )


def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None:
    """Ensure a workspace has default ontology scenes, creating them if missing.

    Checks whether any is_system_default scene exists for the workspace.
    If not, runs the DefaultOntologyInitializer to create them.

    Args:
        db: Database session
        workspace: The workspace to check
    """
    from app.models.ontology_scene import OntologyScene

    # Idempotency check: does a system-default scene already exist?
    existing = db.query(OntologyScene).filter(
        OntologyScene.workspace_id == workspace.id,
        OntologyScene.is_system_default.is_(True)
    ).first()

    if existing:
        return

    business_logger.info(
        f"Workspace {workspace.id} missing default ontology scenes, creating them"
    )

    try:
        initializer = DefaultOntologyInitializer(db)
        success, error_msg = initializer.initialize_default_scenes(
            workspace.id, language="zh"
        )
        if success:
            db.commit()
            business_logger.info(
                f"Backfilled default ontology scenes for workspace {workspace.id}"
            )
        else:
            business_logger.warning(
                f"Failed to backfill default ontology scenes for workspace {workspace.id}: {error_msg}"
            )
    except Exception as e:
        db.rollback()
        business_logger.error(
            f"Exception while backfilling default ontology scenes for workspace {workspace.id}: {str(e)}"
        )


def _create_default_memory_config(
    db: Session,
    workspace_id: uuid.UUID,
311 api/app/tasks.py
@@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
        return result

    try:
    def sync_task():
        trio.run(
            lambda: _run(
                row=task,
@@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
                with_community=with_community,
            )
        )
    try:
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(sync_task)
            future.result()  # Blocks until the task completes
    except Exception as e:
        progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
    progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
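Running trio.run inside a ThreadPoolExecutor worker is the standard way to drive a trio event loop from synchronous Celery code without disturbing the worker's own loop. A self-contained sketch of the same shape:

from concurrent.futures import ThreadPoolExecutor

import trio

async def build_graph() -> str:
    await trio.sleep(0)  # stand-in for the real GraphRAG work
    return "done"

def sync_task() -> str:
    # trio.run starts and tears down a fresh trio loop inside this thread
    return trio.run(build_graph)

with ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(sync_task)
    print(future.result())  # blocks until the trio loop finishes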
@@ -2108,4 +2112,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
    #         "config_id": config_id,
    #         "elapsed_time": elapsed_time,
    #         "task_id": self.request.id
    #     }
    # }


# =============================================================================
# Scheduled task: refresh implicit-memory and emotion data
# =============================================================================

@celery_app.task(
    name="app.tasks.update_implicit_emotions_storage",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # 2-hour hard timeout
    soft_time_limit=6900,  # 1h55m soft timeout
)
def update_implicit_emotions_storage(self) -> Dict[str, Any]:
    """Scheduled task: refresh the implicit-memory profiles and emotion-suggestion data of all users

    Iterates over every user that already has stored data and regenerates the
    implicit-memory profile and emotion suggestions for each of them.
    Errors are isolated: one failing user does not affect the others.

    Returns:
        A dict with the task's execution result, including:
        - status: task status (SUCCESS/FAILURE)
        - message: execution message
        - total_users: total number of users
        - successful_implicit: number of users whose implicit memory was updated
        - successful_emotion: number of users whose emotion suggestions were updated
        - failed: number of failed users
        - user_results: detailed per-user results
        - elapsed_time: execution time in seconds
        - task_id: task ID
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.logging_config import get_logger
        from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
        from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
        from sqlalchemy import select, func
        from app.services.implicit_memory_service import ImplicitMemoryService
        from app.services.emotion_analytics_service import EmotionAnalyticsService

        logger = get_logger(__name__)
        logger.info("Starting the implicit-memory and emotion data refresh task")

        total_users = 0
        successful_implicit = 0
        successful_emotion = 0
        failed = 0
        user_results = []

        with get_db_context() as db:
            try:
                # Fetch the IDs of every user with stored data (processed in batches)
                repo = ImplicitEmotionsStorageRepository(db)

                # Count the total first, for logging
                from sqlalchemy import func
                total_users = db.execute(
                    select(func.count()).select_from(ImplicitEmotionsStorage)
                ).scalar() or 0
                logger.info(f"Found {total_users} users to update")

                # Walk each user and refresh their data (batched, to avoid loading every ID at once)
                for end_user_id in repo.get_all_user_ids(batch_size=100):
                    logger.info(f"Processing user: {end_user_id}")
                    user_start_time = time.time()

                    implicit_success = False
                    emotion_success = False
                    errors = []

                    try:
                        # Refresh the implicit-memory profile
                        try:
                            implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
                            profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
                            await implicit_service.save_profile_cache(
                                end_user_id=end_user_id,
                                profile_data=profile_data,
                                db=db
                            )
                            implicit_success = True
                            logger.info(f"Updated the implicit-memory profile of user {end_user_id}")
                        except Exception as e:
                            error_msg = f"Implicit-memory update failed: {str(e)}"
                            errors.append(error_msg)
                            logger.error(f"User {end_user_id} {error_msg}")

                        # Refresh the emotion suggestions
                        try:
                            emotion_service = EmotionAnalyticsService()
                            suggestions_data = await emotion_service.generate_emotion_suggestions(
                                end_user_id=end_user_id,
                                db=db,
                                language="zh"
                            )
                            await emotion_service.save_suggestions_cache(
                                end_user_id=end_user_id,
                                suggestions_data=suggestions_data,
                                db=db
                            )
                            emotion_success = True
                            logger.info(f"Updated the emotion suggestions of user {end_user_id}")
                        except Exception as e:
                            error_msg = f"Emotion-suggestion update failed: {str(e)}"
                            errors.append(error_msg)
                            logger.error(f"User {end_user_id} {error_msg}")

                        # Tally the results
                        if implicit_success:
                            successful_implicit += 1
                        if emotion_success:
                            successful_emotion += 1
                        if not implicit_success and not emotion_success:
                            failed += 1

                        user_elapsed = time.time() - user_start_time

                        # Record this user's outcome
                        user_result = {
                            "end_user_id": end_user_id,
                            "implicit_success": implicit_success,
                            "emotion_success": emotion_success,
                            "errors": errors,
                            "elapsed_time": user_elapsed
                        }
                        user_results.append(user_result)

                        logger.info(
                            f"User {end_user_id} processed: "
                            f"implicit_memory={'ok' if implicit_success else 'failed'}, "
                            f"emotion_suggestions={'ok' if emotion_success else 'failed'}, "
                            f"elapsed={user_elapsed:.2f}s"
                        )

                    except Exception as e:
                        # One failing user must not affect the others (error isolation)
                        failed += 1
                        user_elapsed = time.time() - user_start_time
                        error_info = {
                            "end_user_id": end_user_id,
                            "implicit_success": False,
                            "emotion_success": False,
                            "errors": [str(e)],
                            "elapsed_time": user_elapsed
                        }
                        user_results.append(error_info)
                        logger.error(f"Error while processing user {end_user_id}: {str(e)}")

                # ---- Process incremental users (added today, not yet initialized) ----
                new_users_initialized = 0
                new_users_failed = 0
                logger.info("Starting the initialization of today's incremental users")

                for end_user_id in repo.get_new_user_ids_today(batch_size=100):
                    logger.info(f"Initializing new user: {end_user_id}")
                    user_start_time = time.time()
                    implicit_success = False
                    emotion_success = False
                    errors = []

                    try:
                        try:
                            implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
                            profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
                            await implicit_service.save_profile_cache(
                                end_user_id=end_user_id,
                                profile_data=profile_data,
                                db=db
                            )
                            implicit_success = True
                            logger.info(f"Initialized the implicit-memory profile of new user {end_user_id}")
                        except Exception as e:
                            error_msg = f"Implicit-memory initialization failed: {str(e)}"
                            errors.append(error_msg)
                            logger.error(f"New user {end_user_id} {error_msg}")

                        try:
                            emotion_service = EmotionAnalyticsService()
                            suggestions_data = await emotion_service.generate_emotion_suggestions(
                                end_user_id=end_user_id,
                                db=db,
                                language="zh"
                            )
                            await emotion_service.save_suggestions_cache(
                                end_user_id=end_user_id,
                                suggestions_data=suggestions_data,
                                db=db
                            )
                            emotion_success = True
                            logger.info(f"Initialized the emotion suggestions of new user {end_user_id}")
                        except Exception as e:
                            error_msg = f"Emotion-suggestion initialization failed: {str(e)}"
                            errors.append(error_msg)
                            logger.error(f"New user {end_user_id} {error_msg}")

                        if implicit_success or emotion_success:
                            new_users_initialized += 1
                        else:
                            new_users_failed += 1

                        user_elapsed = time.time() - user_start_time
                        user_results.append({
                            "end_user_id": end_user_id,
                            "type": "init",
                            "implicit_success": implicit_success,
                            "emotion_success": emotion_success,
                            "errors": errors,
                            "elapsed_time": user_elapsed
                        })

                    except Exception as e:
                        new_users_failed += 1
                        user_elapsed = time.time() - user_start_time
                        user_results.append({
                            "end_user_id": end_user_id,
                            "type": "init",
                            "implicit_success": False,
                            "emotion_success": False,
                            "errors": [str(e)],
                            "elapsed_time": user_elapsed
                        })
                        logger.error(f"Error while initializing new user {end_user_id}: {str(e)}")

                logger.info(
                    f"Incremental user initialization finished: succeeded={new_users_initialized}, failed={new_users_failed}"
                )
                # ---- End of incremental user processing ----
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||
f"存量用户总数={total_users}, "
|
||||
f"隐性记忆成功={successful_implicit}, "
|
||||
f"情绪建议成功={successful_emotion}, "
|
||||
f"存量失败={failed}, "
|
||||
f"增量初始化成功={new_users_initialized}, "
|
||||
f"增量初始化失败={new_users_failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": (
|
||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": new_users_initialized,
|
||||
"new_users_failed": new_users_failed,
|
||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": 0,
|
||||
"new_users_failed": 0,
|
||||
"user_results": user_results[:50]
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
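
The loops above rely on ImplicitEmotionsStorageRepository.get_all_user_ids(batch_size=100) to stream IDs without loading them all into memory at once. The repository method itself is not part of this diff; a minimal sketch of how such a generator could be written, assuming keyset pagination over the unique end_user_id column (the function body is hypothetical, not the repository's actual code):

from typing import Iterator

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage


def get_all_user_ids(db: Session, batch_size: int = 100) -> Iterator[str]:
    """Hypothetical sketch: stream end_user_id values in batches via keyset pagination."""
    last_id = ""
    while True:
        rows = db.execute(
            select(ImplicitEmotionsStorage.end_user_id)
            .where(ImplicitEmotionsStorage.end_user_id > last_id)
            .order_by(ImplicitEmotionsStorage.end_user_id)
            .limit(batch_size)
        ).scalars().all()
        if not rows:
            break
        yield from rows      # hand out one batch before fetching the next
        last_id = rows[-1]   # keyset cursor: strictly greater than the last seen ID

Keyset pagination keeps each query cheap regardless of table size, which matters here because the task may run close to its two-hour limit.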
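With time_limit=7200 and soft_time_limit=6900, Celery raises SoftTimeLimitExceeded inside the task roughly five minutes before the hard kill. The task above does not catch it explicitly (the generic except blocks absorb it per user); a standalone sketch of the standard pattern for returning partial results instead, using a stand-in Celery app (stream_items and handle are placeholders, not project code):

from celery import Celery
from celery.exceptions import SoftTimeLimitExceeded

app = Celery("sketch")  # stand-in for the project's existing celery_app


@app.task(bind=True, time_limit=7200, soft_time_limit=6900)
def long_batch_task(self):
    processed = 0
    try:
        for item in stream_items():   # placeholder work generator
            handle(item)              # placeholder per-item handler
            processed += 1
    except SoftTimeLimitExceeded:
        # Raised ~300s before the hard limit; report partial progress gracefully.
        return {"status": "PARTIAL", "processed": processed}
    return {"status": "SUCCESS", "processed": processed}
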
@@ -29,10 +29,10 @@ REDIS_DB=
REDIS_PASSWORD=password

#celery
BROKER_URL=
RESULT_BACKEND=
CELERY_BROKER=
CELERY_BACKEND=
# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND,
# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md
REDIS_DB_CELERY_BROKER=
REDIS_DB_CELERY_BACKEND=

# Memory Cache Regeneration Configuration
# Interval in hours for regenerating memory insight and user summary cache
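
Because the broker and backend are now expressed as Redis DB numbers rather than full URLs (to avoid the env-name hijack described in the NOTE), the application presumably assembles the URLs itself from the shared REDIS_* settings. A minimal sketch of that assembly; the helper name and exact wiring into the Celery config are assumptions, only the environment variable names come from this diff:

import os


def redis_url(db: int) -> str:
    """Build a redis:// URL from the shared REDIS_* settings plus a per-purpose DB number."""
    host = os.getenv("REDIS_HOST", "127.0.0.1")
    port = os.getenv("REDIS_PORT", "6379")
    password = os.getenv("REDIS_PASSWORD", "")
    auth = f":{password}@" if password else ""
    return f"redis://{auth}{host}:{port}/{db}"


broker_url = redis_url(int(os.getenv("REDIS_DB_CELERY_BROKER", "1")))
result_backend = redis_url(int(os.getenv("REDIS_DB_CELERY_BACKEND", "2")))
# e.g. celery_app.conf.broker_url = broker_url  (exact wiring depends on the app factory)
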
43
api/migrations/versions/6a4641cf192b_202603051440.py
Normal file
@@ -0,0 +1,43 @@
"""202603051440

Revision ID: 6a4641cf192b
Revises: b4af97639217
Create Date: 2026-03-05 14:41:03.371557

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '6a4641cf192b'
down_revision: Union[str, None] = 'b4af97639217'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('implicit_emotions_storage',
        sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'),
        sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
        sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'),
        sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'),
        sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
        sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
        sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'),
        sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('end_user_id')
    )
    op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index('idx_updated_at', table_name='implicit_emotions_storage')
    op.drop_table('implicit_emotions_storage')
    # ### end Alembic commands ###
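
The UNIQUE constraint on end_user_id makes per-user writes naturally idempotent via a PostgreSQL upsert, which suits the nightly regeneration task that rewrites the same rows on every run. A sketch of how a repository write against this table might look; the function name and commit policy are assumptions, only the table and column names come from the migration above:

import uuid
from datetime import datetime

from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage


def upsert_profile(db: Session, end_user_id: str, profile: dict) -> None:
    """Insert or refresh a user's implicit profile, keyed on the unique end_user_id."""
    now = datetime.now()
    stmt = insert(ImplicitEmotionsStorage).values(
        id=uuid.uuid4(),  # used only on first insert; ignored when the row exists
        end_user_id=end_user_id,
        implicit_profile=profile,
        implicit_generated_at=now,
        created_at=now,
        updated_at=now,
    ).on_conflict_do_update(
        index_elements=["end_user_id"],  # the UNIQUE constraint from this migration
        set_={
            "implicit_profile": profile,
            "implicit_generated_at": now,
            "updated_at": now,
        },
    )
    db.execute(stmt)
    db.commit()
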
63
api/migrations/versions/b4af97639217_202603051033.py
Normal file
@@ -0,0 +1,63 @@
"""202603051033

Revision ID: b4af97639217
Revises: 4bf27c66ae63
Create Date: 2026-03-05 10:36:06.282227

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = 'b4af97639217'
down_revision: Union[str, None] = '4bf27c66ae63'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    # Add columns as nullable first to avoid table locks
    op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
    op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))

    op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
    op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))

    op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
    op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))

    # Update existing rows with default values
    op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL")
    op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL")

    op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL")
    op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL")

    op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL")
    op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL")

    # Now make columns NOT NULL
    op.alter_column('model_api_keys', 'capability', nullable=False)
    op.alter_column('model_api_keys', 'is_omni', nullable=False)

    op.alter_column('model_bases', 'capability', nullable=False)
    op.alter_column('model_bases', 'is_omni', nullable=False)

    op.alter_column('model_configs', 'capability', nullable=False)
    op.alter_column('model_configs', 'is_omni', nullable=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('model_configs', 'is_omni')
    op.drop_column('model_configs', 'capability')
    op.drop_column('model_bases', 'is_omni')
    op.drop_column('model_bases', 'capability')
    op.drop_column('model_api_keys', 'is_omni')
    op.drop_column('model_api_keys', 'capability')
    # ### end Alembic commands ###
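
The add-nullable → backfill → NOT NULL sequence above avoids holding a long lock while existing rows are rewritten. On PostgreSQL 11+, a constant server_default achieves the same result in a single step, since the default is stored in the catalog without a table rewrite. An alternative sketch for one table (not what this migration does; shown only as a design note):

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # PG 11+ records a constant default in the catalog, so this neither rewrites
    # the table nor backfills rows one by one.
    op.add_column(
        "model_api_keys",
        sa.Column("is_omni", sa.Boolean(), nullable=False, server_default=sa.false()),
    )
    op.add_column(
        "model_api_keys",
        sa.Column("capability", sa.ARRAY(sa.String()), nullable=False, server_default="{}"),
    )
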
@@ -163,9 +163,14 @@ export const getImplicitInterestAreas = (end_user_id: string) => {
export const getImplicitHabits = (end_user_id: string) => {
  return request.get(`/memory/implicit-memory/habits/${end_user_id}`)
}
// Implicit Memory - Generate user portrait
export const generateProfile = (end_user_id: string) => {
  return request.post(`/memory/implicit-memory/generate_profile`, { end_user_id })
}
// Implicit Memory - Check if data exists
export const implicitCheckData = (end_user_id: string) => {
  return request.get(`/memory/implicit-memory/check-data/${end_user_id}`)
}
// Short-term memory
export const getShortTerm = (end_user_id: string) => {
  return request.get(`/memory/short/short_term`, { end_user_id })

@@ -1,16 +1,21 @@
import { type FC, useRef, useState } from 'react'
import RecordRTC from 'recordrtc'

import { fileUpload } from '@/api/fileStorage'
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
import { request } from '@/utils/request'

interface AudioRecorderProps {
  onRecordingComplete?: (file: { file_id: string; file_key: string; }, blob: Blob) => void
  className?: string
  onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void
  className?: string;
  action?: string;
  requestConfig?: Record<string, any>;
}

const AudioRecorder: FC<AudioRecorderProps> = ({
  onRecordingComplete,
  className = '',
  action = fileUploadUrlWithoutApiPrefix,
  requestConfig = {}
}) => {
  const [isRecording, setIsRecording] = useState(false)
  const recorderRef = useRef<RecordRTC | null>(null)
@@ -33,11 +38,17 @@ const AudioRecorder: FC<AudioRecorderProps> = ({
    if (recorderRef.current) {
      recorderRef.current.stopRecording(() => {
        const blob = recorderRef.current!.getBlob()
        const url = recorderRef.current!.toURL()
        const formData = new FormData()
        formData.append('file', blob, `recording_${Date.now()}.webm`)
        fileUpload(formData)
        request
          .uploadFile(action, formData, requestConfig)
          .then(res => {
            onRecordingComplete?.(res as { file_id: string; file_key: string; }, blob)
            onRecordingComplete?.({
              ...(res as { file_id: string; file_key: string }),
              type: blob.type,
              url
            }, blob)
            recorderRef.current?.destroy()
            recorderRef.current = null
          })

@@ -2,10 +2,11 @@
 * @Author: ZhaoYing
 * @Date: 2025-12-10 16:46:14
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-10 12:13:52
 * @Last Modified time: 2026-03-04 18:42:49
 */
import { type FC, useEffect, useMemo } from 'react'
import { Flex, Input, Form } from 'antd'

import SendIcon from '@/assets/images/conversation/send.svg'
import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg'
import LoadingIcon from '@/assets/images/conversation/loading.svg'
@@ -80,9 +81,31 @@ const ChatInput: FC<ChatInputProps> = ({
        </div>
      )
    }
    if (file.type.includes('video')) {
      return (
        <div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-block rb:group rb:relative rb:rounded-lg">
          <video src={file.url} controls className="rb:w-45 rb:h-16 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
          <div
            className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
            onClick={() => handleDelete(file)}
          ></div>
        </div>
      )
    }
    if (file.type.includes('audio')) {
      return (
        <div key={file.uid} className="rb:w-45 rb:h-16 rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
          <audio src={file.url} controls className="rb:w-45 rb:h-16" />
          <div
            className="rb:hidden rb:group-hover:block rb:absolute rb:-right-1 rb:-top-1 rb:size-3.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/delete.svg')] rb:hover:bg-[url('@/assets/images/conversation/delete_hover.svg')]"
            onClick={() => handleDelete(file)}
          ></div>
        </div>
      )
    }
    return (
      <div key={file.uid} className="rb:w-45 rb:text-[12px] rb:gap-2.5 rb:flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5">
        {(file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
        {(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
          className="rb:size-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word_disabled.svg')] rb:hover:bg-[url('@/assets/images/conversation/word.svg')]"
        ></div>}
        {(file.type.includes('pdf')) && <div

@@ -453,9 +453,11 @@ export const en = {
    prevStep: 'Previous Step',
    exportSuccess: 'Export successful',
    recommend: 'Recommend',
    default: 'Default',
    logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
    imageSquareRequired: 'Please upload a square image',
    nameInvalid: 'Name cannot start or end with a space',
    notAllSpaces: 'Cannot be all spaces',
  },
  model: {
    searchPlaceholder: 'search model…',
@@ -603,7 +605,13 @@ export const en = {
    ollama: "Ollama",
    xinference: "Xinference",
    gpustack: "Gpustack",
    bedrock: "Bedrock"
    bedrock: "Bedrock",

    is_vision: 'Vision Support',
    is_omni: 'Omni Support',
    vision: 'Vision',
    audio: 'Audio',
    video: 'Video',
  },
  knowledgeBase: {
    home: 'Home',
@@ -1684,6 +1692,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
    uploadFile: 'Upload File',
    fileType: 'File Type',
    image: 'Image',
    video: 'Video',
    audio: 'Audio',
    fileUrl: 'File URL',
    addRemoteFile: 'Add Remote File',
    variableConfig: 'Variable Configuration',
@@ -1962,6 +1972,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
    marketConnected: '● Connected',
    marketDisconnected: '○ Disconnected',
    marketConnecting: 'Connecting to {{name}}...',
    serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
    requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
  },
  workflow: {
    coreNode: 'Core Nodes',
@@ -2311,6 +2323,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
    suggestions: 'Personalized Suggestions',
    suggestionLoading: 'Your personalized suggestions are being generated',
    item: 'item',
    noData: 'Emotion suggestion data does not exist, please click the refresh button to initialize',
  },
  reflectionEngine: {
    reflectionEngineConfig: 'Reflection Engine Configuration',
@@ -2557,7 +2570,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
    context_details: 'Preference Details',
    supporting_evidence: 'Preference Source',
    specific_examples: 'Source',
    wordEmpty: 'Click on a node in the left chart to view preference details'
    wordEmpty: 'Click on a node in the left chart to view preference details',
    noData: 'Portrait data does not exist, please click the refresh button to initialize',
  },
  shortTermDetail: {
    title: 'Short-term memory is the "workbench" of the AI system, connecting instant conversations with long-term knowledge bases. Through real-time capture, deep retrieval, intelligent extraction and filtering transformation, temporary unstructured information is converted into valuable long-term knowledge.',
@@ -2615,6 +2629,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
    updated_at: 'Updated At',
    entityTypes: 'Entity Types',

    classSearchPlaceholder: 'Search types',
    addClass: 'Add Type',
    class_name: 'Type Name',
    class_description: 'Type Definition',

@@ -1033,9 +1033,11 @@ export const zh = {
    prevStep: '上一步',
    exportSuccess: '导出成功',
    recommend: '推荐',
    default: '默认',
    logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
    imageSquareRequired: '请上传正方形比例图片',
    nameInvalid: '不能是空格开头或结尾',
    notAllSpaces: '不能是纯空格',
  },
  model: {
    searchPlaceholder: '搜索模型…',
@@ -1184,6 +1186,12 @@ export const zh = {
    xinference: "Xinference",
    gpustack: "Gpustack",
    bedrock: "Bedrock",

    is_vision: '支持视觉',
    is_omni: '支持全模态',
    vision: '视觉',
    audio: '音频',
    video: '视频',
  },
  timezones: {
    'Asia/Shanghai': '中国标准时间 (UTC+8)',
@@ -1681,6 +1689,8 @@ export const zh = {
    uploadFile: '上传文件',
    fileType: '文件类型',
    image: '图片',
    video: '视频',
    audio: '音频',
    fileUrl: '文件链接',
    addRemoteFile: '添加远程文件',
    variableConfig: '变量配置',
@@ -1959,6 +1969,8 @@ export const zh = {
    marketConnected: '● 已连接',
    marketDisconnected: '○ 未连接',
    marketConnecting: '正在连接 {{name}}...',
    serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
    requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
  },
  workflow: {
    coreNode: '核心节点',
@@ -2312,6 +2324,7 @@ export const zh = {
    suggestions: '个性化建议',
    suggestionLoading: '您的个性化建议正在生成中',
    item: '个',
    noData: '情绪建议数据不存在,请点击刷新按钮进行初始化',
  },
  reflectionEngine: {
    reflectionEngineConfig: '反思引擎配置',
@@ -2558,7 +2571,8 @@ export const zh = {
    context_details: '偏好详情',
    supporting_evidence: '偏好来源',
    specific_examples: '来源',
    wordEmpty: '点击左侧图表中的节点查看偏好详情'
    wordEmpty: '点击左侧图表中的节点查看偏好详情',
    noData: '画像数据不存在,请点击刷新按钮进行初始化',
  },
  shortTermDetail: {
    title: '短期记忆是AI系统的"工作台",连接即时对话与长期知识库。通过实时捕获、深度检索、智能提取和筛选转化,将临时的非结构化信息转化为有价值的长期知识。',
@@ -2616,6 +2630,7 @@ export const zh = {
    updated_at: '更新时间',
    entityTypes: '实体类型',

    classSearchPlaceholder: '搜索类型',
    addClass: '添加类型',
    class_name: '类型名称',
    class_description: '类型定义',

@@ -1,8 +1,8 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-02-02 16:35:15
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-02 16:35:15
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-06 10:39:00
 */
/**
 * HTTP Request Utility Module
@@ -183,7 +183,9 @@ service.interceptors.response.use(
        msg = msg || i18n.t('common.serverError');
        break;
      default:
        if (!msg && Array.isArray(error.response?.data?.detail)) {
        if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) {
          msg = i18n.t(`common.${msg}`)
        } else if (!msg && Array.isArray(error.response?.data?.detail)) {
          msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';')
        } else {
          msg = msg || i18n.t('common.unknownError');

@@ -1,8 +1,8 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-02-02 16:35:43
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-02 16:35:43
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-04 18:19:24
 */
/**
 * Server-Sent Events (SSE) Stream Utility Module
@@ -176,17 +176,17 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
      case 500:
      case 502:
        const errorData = await response.json();
        errorData.error || i18n.t('common.serviceUpgrading');
        message.warning(errorData.error || i18n.t('common.serviceUpgrading'));
        return;
        let errorInfo = errorData.error || i18n.t('common.serviceUpgrading')
        message.warning(errorInfo);
        throw errorInfo;
      case 400:
        const error = await response.json();
        message.warning(error.error);
        throw error || 'Bad Request';
        throw error.error || 'Bad Request';
      case 504:
        const errorJson = await response.json();
        message.warning(errorJson.error || i18n.t('common.serverError'));
        return;
        throw errorJson.error;
      case 401:
        if (url?.includes('/public')) {
          return message.warning(i18n.t('common.publicApiCannotRefreshToken'));

@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-03 16:27:39
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-03 14:21:54
 * @Last Modified time: 2026-03-05 17:03:46
 */
/**
 * Chat debugging component for application testing
@@ -13,7 +13,7 @@
import { type FC, useEffect, useState, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import clsx from 'clsx'
import { Flex, Dropdown, type MenuProps, App } from 'antd'
import { Flex, Dropdown, type MenuProps, App, Divider } from 'antd'

import ChatIcon from '@/assets/images/application/chat.png'
import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png'
@@ -25,7 +25,7 @@ import type { ChatItem } from '@/components/Chat/types'
import { type SSEMessage } from '@/utils/stream'
import ChatInput from '@/components/Chat/ChatInput'
import UploadFiles from '@/views/Conversation/components/FileUpload'
// import AudioRecorder from '@/components/AudioRecorder'
import AudioRecorder from '@/components/AudioRecorder'
import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal'
import type { UploadFileListModalRef } from '@/views/Conversation/types'
import type { Variable } from './VariableList/types'
@@ -88,7 +88,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
      content: '',
      created_at: Date.now(),
    };


    if (isCluster) {
      updateChatList(prev => prev.map(item => ({
        ...item,
@@ -134,7 +134,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
    })
  }
  /** Update assistant message when error occurs */
  const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => {
  const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => {
    if (message_length > 0 || !model_config_id) return

    updateChatList(prev => {
@@ -171,6 +171,29 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
      .then(() => {
        const message = msg
        if (!message?.trim()) return
        // Validate required variables before sending
        let isCanSend = true
        const params: Record<string, any> = {}
        if (chatVariables && chatVariables.length > 0) {
          const needRequired: string[] = []
          chatVariables.forEach(vo => {
            params[vo.name] = vo.value

            if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
              isCanSend = false
              needRequired.push(vo.name)
            }
          })

          if (needRequired.length) {
            messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
          }
        }
        if (!isCanSend) {
          setLoading(false)
          setCompareLoading(false)
          return
        }

        addUserMessage(message, fileList)
        setMessage(message)
@@ -198,27 +221,6 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
      };

      setTimeout(() => {
        // Validate required variables before sending
        let isCanSend = true
        const params: Record<string, any> = {}
        if (chatVariables && chatVariables.length > 0) {
          const needRequired: string[] = []
          chatVariables.forEach(vo => {
            params[vo.name] = vo.value

            if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) {
              isCanSend = false
              needRequired.push(vo.name)
            }
          })

          if (needRequired.length) {
            messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`)
          }
        }
        if (!isCanSend) {
          return
        }
        runCompare(data.app_id, {
          message,
          files: fileList.map(file => {
@@ -243,7 +245,15 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
          "stream": true,
          "timeout": 60,
        }, handleStreamMessage)
          .finally(() => setLoading(false));
          .catch(() => {
            setLoading(false)
            setCompareLoading(false)
            updateClusterErrorAssistantMessage(0)
          })
          .finally(() => {
            setLoading(false)
            setCompareLoading(false)
          })
      }, 0)
    })
    .catch(() => {
@@ -288,7 +298,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
    })
  }
  /** Update cluster message when error occurs */
  const updateClusterErrorAssistantMessage = (message_length: number) => {
  const updateClusterErrorAssistantMessage = (message_length: number) => {
    if (message_length > 0) return

    updateChatList(prev => {
@@ -331,7 +341,7 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
    data.map(item => {
      const { conversation_id, content, message_length } = item.data as { conversation_id: string, content: string, message_length: number };

      switch(item.event) {
      switch (item.event) {
        case 'start':
          if (conversation_id && conversationId !== conversation_id) {
            setConversationId(conversation_id);
@@ -354,27 +364,35 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
      };

      setTimeout(() => {
        draftRun(
          data.app_id,
          {
            message,
            conversation_id: conversationId,
            stream: true,
            files: fileList.map(file => {
              if (file.url) {
                return file
              } else {
                return {
                  type: file.type,
                  transfer_method: 'local_file',
                  upload_file_id: file.response.data.file_id
                }
        draftRun(
          data.app_id,
          {
            message,
            conversation_id: conversationId,
            stream: true,
            files: fileList.map(file => {
              if (file.url) {
                return file
              } else {
                return {
                  type: file.type,
                  transfer_method: 'local_file',
                  upload_file_id: file.response.data.file_id
                }
              }
            }),
          },
          handleStreamMessage
        )
          .finally(() => setLoading(false))
              }
            }),
          },
          handleStreamMessage
        )
          .catch(() => {
            setLoading(false)
            setCompareLoading(false)
            updateClusterErrorAssistantMessage(0)
          })
          .finally(() => {
            setLoading(false)
            setCompareLoading(false)
          })
      }, 0)
    })
    .catch(() => {
@@ -393,12 +411,17 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
  const fileChange = (file?: any) => {
    setFileList([...fileList, file])
  }
  // const handleRecordingComplete = async (file: any) => {
  //   console.log('file', file)
  // }
  const handleRecordingComplete = async (file: any) => {
    setFileList([...fileList, {
      uid: file.file_id,
      response: { data: file },
      thumbUrl: file.url,
      type: file.type
    }])
  }

  const handleShowUpload: MenuProps['onClick'] = ({ key }) => {
    switch(key) {
    switch (key) {
      case 'define':
        uploadFileListModalRef.current?.handleOpen()
        break
@@ -415,99 +438,98 @@ const Chat: FC<ChatProps> = ({ chatList, data, updateChatList, handleSave, sourc
  return (
    <div className="rb:relative rb:h-full rb:flex rb:flex-col">
      {chatList.length === 0
        ? <Empty
          url={DebuggingEmpty}
        ? <Empty
          url={DebuggingEmpty}
          size={[300, 200]}
          title={t('application.debuggingEmpty')}
          subTitle={t('application.debuggingEmptyDesc')}
          title={t('application.debuggingEmpty')}
          subTitle={t('application.debuggingEmptyDesc')}
          className="rb:h-full"
        />
        : <>
          <div className={clsx(`rb:relative rb:grid rb:grid-cols-${chatList.length} rb:overflow-hidden rb:w-full rb:flex-1 rb:min-h-0`)}>
            {chatList.map((chat, index) => (
              <div key={index} className={clsx('rb:flex rb:flex-col', {
                "rb:border-r rb:border-[#DFE4ED]": index !== chatList.length - 1 && chatList.length > 1,
              })}>
                {chat.label &&
                  <div className={clsx(
                    "rb:grid rb:bg-[#F0F3F8] rb:text-center rb:flex-[0_0_auto]",
                    {
                      'rb:rounded-tr-xl': index === chatList.length - 1,
                      'rb:rounded-tl-xl': index === 0,
                    }
                  )}>
                    <div className='rb:relative rb:p-[10px_12px] rb:overflow-hidden'>
                      <div className="rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap rb:w-[calc(100%-24px)]">{chat.label}</div>
                      <div
                        className="rb:w-4 rb:h-4 rb:cursor-pointer rb:absolute rb:top-3 rb:right-3 rb:bg-cover rb:bg-[url('@/assets/images/close.svg')] rb:hover:bg-[url('@/assets/images/close_hover.svg')]"
                        onClick={() => handleDelete(index)}
                      ></div>
        : <>
          <div className={clsx(`rb:relative rb:grid rb:grid-cols-${chatList.length} rb:overflow-hidden rb:w-full rb:flex-1 rb:min-h-0`)}>
            {chatList.map((chat, index) => (
              <div key={index} className={clsx('rb:flex rb:flex-col', {
                "rb:border-r rb:border-[#DFE4ED]": index !== chatList.length - 1 && chatList.length > 1,
              })}>
                {chat.label &&
                  <div className={clsx(
                    "rb:grid rb:bg-[#F0F3F8] rb:text-center rb:flex-[0_0_auto]",
                    {
                      'rb:rounded-tr-xl': index === chatList.length - 1,
                      'rb:rounded-tl-xl': index === 0,
                    }
                  )}>
                    <div className='rb:relative rb:p-[10px_12px] rb:overflow-hidden'>
                      <div className="rb:text-ellipsis rb:overflow-hidden rb:whitespace-nowrap rb:w-[calc(100%-24px)]">{chat.label}</div>
                      <div
                        className="rb:w-4 rb:h-4 rb:cursor-pointer rb:absolute rb:top-3 rb:right-3 rb:bg-cover rb:bg-[url('@/assets/images/close.svg')] rb:hover:bg-[url('@/assets/images/close_hover.svg')]"
                        onClick={() => handleDelete(index)}
                      ></div>
                    </div>
                  </div>
                </div>
              }
              <ChatContent
                classNames={{
                  'rb:mx-[16px] rb:mt-6': true,
                  'rb:h-[calc(100vh-282px)]': isCluster,
                  'rb:h-[calc(100vh-380px)]': !isCluster,
                }}
                contentClassNames={{
                  'rb:max-w-[400px]!': chatList.length === 1,
                  'rb:max-w-[260px]!': chatList.length === 2,
                  'rb:max-w-[150px]!': chatList.length === 3,
                  'rb:max-w-[108px]!': chatList.length === 4,
                }}
                empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
                data={chat.list || []}
                streamLoading={compareLoading}
                labelPosition="top"
                labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label}
                errorDesc={t('application.ReplyException')}
              />
            </div>
          ))}
        </div>
        <div className="rb:relative rb:flex rb:items-center rb:gap-2.5 rb:m-4 rb:mb-1">
          <ChatInput
            message={message}
            className="rb:relative!"
            loading={loading}
            fileChange={updateFileList}
            fileList={fileList}
            onSend={isCluster ? handleClusterSend : handleSend}
            onChange={handleMessageChange}
          >
            <Flex justify="space-between" className="rb:flex-1">
              <Flex gap={8} align="center">
                <Dropdown
                  menu={{
                    items: [
                      { key: 'define', label: t('memoryConversation.addRemoteFile') },
                      {
                        key: 'upload', label: (
                          <UploadFiles
                            fileType={['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg']}
                            onChange={fileChange}
                          />
                        )
                      },
                    ],
                    onClick: handleShowUpload
                  }
                <ChatContent
                  classNames={{
                    'rb:mx-[16px] rb:mt-6': true,
                    'rb:h-[calc(100vh-282px)]': isCluster,
                    'rb:h-[calc(100vh-380px)]': !isCluster,
                  }}
                >
                  <div
                    className="rb:size-6 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/link.svg')] rb:hover:bg-[url('@/assets/images/conversation/link_hover.svg')]"
                  ></div>
                </Dropdown>
                  contentClassNames={{
                    'rb:max-w-[400px]!': chatList.length === 1,
                    'rb:max-w-[260px]!': chatList.length === 2,
                    'rb:max-w-[150px]!': chatList.length === 3,
                    'rb:max-w-[108px]!': chatList.length === 4,
                  }}
                  empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
                  data={chat.list || []}
                  streamLoading={compareLoading}
                  labelPosition="top"
                  labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label}
                  errorDesc={t('application.ReplyException')}
                />
              </div>
            ))}
          </div>
          <div className="rb:relative rb:flex rb:items-center rb:gap-2.5 rb:m-4 rb:mb-1">
            <ChatInput
              message={message}
              className="rb:relative!"
              loading={loading}
              fileChange={updateFileList}
              fileList={fileList}
              onSend={isCluster ? handleClusterSend : handleSend}
              onChange={handleMessageChange}
            >
              <Flex justify="space-between" className="rb:flex-1">
                <Flex gap={8} align="center">
                  <Dropdown
                    menu={{
                      items: [
                        { key: 'define', label: t('memoryConversation.addRemoteFile') },
                        {
                          key: 'upload', label: (
                            <UploadFiles
                              onChange={fileChange}
                            />
                          )
                        },
                      ],
                      onClick: handleShowUpload
                    }}
                  >
                    <div
                      className="rb:size-6 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/link.svg')] rb:hover:bg-[url('@/assets/images/conversation/link_hover.svg')]"
                    ></div>
                  </Dropdown>
                </Flex>
                <Flex align="center">
                  <AudioRecorder onRecordingComplete={handleRecordingComplete} />
                  <Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
                </Flex>
              </Flex>
              {/* <Flex align="center">
                <AudioRecorder onRecordingComplete={handleRecordingComplete} />
                <Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
              </Flex> */}
            </Flex>
          </ChatInput>
        </div>
      </>
            </ChatInput>
          </div>
        </>
      }

      <UploadFileListModal

@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-06 21:09:42
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-11 11:32:48
 * @Last Modified time: 2026-03-05 15:09:22
 */
/**
 * File Upload Component
@@ -25,6 +25,7 @@ import { Upload, Progress, App } from 'antd';
import type { UploadProps, UploadFile } from 'antd';
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
import { useTranslation } from 'react-i18next';

import { request } from '@/utils/request'
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'

@@ -56,27 +57,36 @@ interface UploadFilesProps extends Omit<UploadProps, 'onChange'> {
  /** Custom file removal callback */
  onRemove?: (file: UploadFile) => boolean | void | Promise<boolean | void>;
}

const transform_file_type = {
  'text/plain': 'document/text',
  'text/markdown': 'document/markdown',
  'text/x-markdown': 'document/x-markdown',

  'application/pdf': 'document/pdf',

  'application/msword': 'document/doc',
  'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',

  'application/vnd.ms-powerpoint': 'document/ppt',
  'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
}
// Mapping of file extensions to MIME types
const ALL_FILE_TYPE: {
  [key: string]: string;
} = {
  // txt: 'text/plain',
  txt: 'text/plain',
  md: 'text/markdown',
  xmd: 'text/x-markdown',

  pdf: 'application/pdf',

  doc: 'application/msword',
  docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',

  xls: 'application/vnd.ms-excel',
  xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
  csv: 'text/csv',

  ppt: 'application/vnd.ms-powerpoint',
  pptx: 'application/vnd.openxmlformats-officedocument.presentationml.presentation',

  // md: 'text/markdown',
  // htm: 'text/html',
  // html: 'text/html',
  // json: 'application/json',

  jpg: 'image/jpeg',
  jpeg: 'image/jpeg',
  png: 'image/png',
@@ -84,6 +94,23 @@ const ALL_FILE_TYPE: {
  bmp: 'image/bmp',
  webp: 'image/webp',
  svg: 'image/svg+xml',

  mp4: 'video/mp4',
  mov: 'video/quicktime',
  avi: 'video/x-msvideo',
  mkv: 'video/x-matroska',
  webm: 'video/webm',
  flv: 'video/x-flv',
  wmv: 'video/x-ms-wmv',

  mp3: 'audio/mpeg',
  wav: 'audio/wav',
  ogg: 'audio/ogg',
  aac: 'audio/aac',
  flac: 'audio/flac',
  m4a: 'audio/mp4',
  wma: 'audio/x-ms-wma',
  xm4a: 'audio/x-m4a',
}
export interface UploadFilesRef {
  /** Current file list */
@@ -178,6 +205,10 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
   * Handles upload state changes
   */
  const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => {
    newFileList.map(file => {
      const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document'
      file.type = type
    })
    setFileList(newFileList);
    if (onChange) {
      onChange(maxCount === 1 ? newFileList[newFileList.length - 1] : newFileList);

@@ -2,7 +2,7 @@
 * @Author: ZhaoYing
 * @Date: 2026-02-06 21:09:47
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-09 10:17:54
 * @Last Modified time: 2026-03-04 17:47:09
 */
/**
 * Upload File List Modal Component
@@ -104,7 +104,9 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
          <Select
            placeholder={t('memoryConversation.fileType')}
            options={[
              { label: t('memoryConversation.image'), value: 'image' }
              { label: t('memoryConversation.image'), value: 'image' },
              { label: t('memoryConversation.audio'), value: 'audio' },
              { label: t('memoryConversation.video'), value: 'video' },
            ]}
            className="rb:w-30"
          />

@@ -14,7 +14,7 @@ import { type FC, useState, useEffect, useRef } from 'react'
import { useParams, useLocation } from 'react-router-dom'
import { useTranslation } from 'react-i18next'
import InfiniteScroll from 'react-infinite-scroll-component';
import { Flex, Skeleton, Form, Dropdown, type MenuProps, App } from 'antd'
import { Flex, Skeleton, Form, Dropdown, type MenuProps, App, Divider } from 'antd'
import { SettingOutlined } from '@ant-design/icons'
import clsx from 'clsx'
import dayjs from 'dayjs'
@@ -35,7 +35,7 @@ import OnlineCheckedIcon from '@/assets/images/conversation/onlineChecked.svg'
import MemoryFunctionCheckedIcon from '@/assets/images/conversation/memoryFunctionChecked.svg'
import { type SSEMessage } from '@/utils/stream'
import UploadFiles from './components/FileUpload'
// import AudioRecorder from '@/components/AudioRecorder'
import AudioRecorder from '@/components/AudioRecorder'
import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
import UploadFileListModal from './components/UploadFileListModal'
import type { VariableConfigModalRef } from '@/views/Workflow/types'
@@ -305,17 +305,27 @@ const Conversation: FC = () => {
      }),
      variables: params
    }, handleStreamMessage, shareToken)
      .catch(() => {
        setLoading(false)
        setStreamLoading(false)
      })
      .finally(() => {
        setLoading(false)
        setStreamLoading(false)
      })
  }

  const fileChange = (file?: any) => {
    form.setFieldValue('files', [...(queryValues.files || []), file])
  }
  // const handleRecordingComplete = async (file: any) => {
  //   console.log('file', file)
  // }
  const handleRecordingComplete = async (file: any) => {
    form.setFieldValue('files', [...(queryValues.files || []), {
      uid: file.file_id,
      response: { data: file },
      thumbUrl: file.url,
      type: file.type
    }])
  }

  const handleShowUpload: MenuProps['onClick'] = ({ key }) => {
    switch(key) {
@@ -329,6 +339,7 @@ const Conversation: FC = () => {
    form.setFieldValue('files', [...(queryValues.files || []), ...fileList])
  }
  const updateFileList = (fileList?: any[]) => {
    console.log('fileList', fileList)
    form.setFieldValue('files', [...(fileList || [])])
  }

@@ -383,7 +394,7 @@ const Conversation: FC = () => {
    <div className='rb:w-190 rb:h-screen rb:mx-auto rb:pt-10'>
      <Chat
        empty={<Empty url={ChatEmpty} className="rb:h-full" size={[320,180]} title={t('memoryConversation.chatEmpty')} subTitle={t('memoryConversation.emptyDesc')} />}
        contentClassName="rb:h-[calc(100%-180px)]"
        contentClassName={!queryValues?.files?.length ? "rb:h-[calc(100%-144px)]" : "rb:h-[calc(100%-208px)]"}
        data={chatList}
        streamLoading={streamLoading}
        loading={loading}
@@ -405,13 +416,12 @@ const Conversation: FC = () => {
                key: 'upload', label: (
                  <UploadFiles
                    action={shareFileUploadUrlWithoutApiPrefix}
                    fileType={['jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg']}
                    onChange={fileChange}
                    requestConfig={{
                      headers: {
                        'Content-Type': 'multipart/form-data',
                        Authorization: `Bearer ${shareToken || ''}`,
                      } }}
                    }}}
                  />
                )
              },
@@ -455,10 +465,19 @@ const Conversation: FC = () => {
            </Form.Item>
          )}
        </Flex>
        {/* <Flex align="center">
          <AudioRecorder onRecordingComplete={handleRecordingComplete} />
        <Flex align="center">
          <AudioRecorder
            action={shareFileUploadUrlWithoutApiPrefix}
            requestConfig={{
              headers: {
                'Content-Type': 'multipart/form-data',
                Authorization: `Bearer ${shareToken || ''}`,
              }
            }}
            onRecordingComplete={handleRecordingComplete}
          />
          <Divider type="vertical" className="rb:ml-1.5! rb:mr-3!" />
        </Flex> */}
        </Flex>
      </Flex>
    </Form>
  </Chat>

@@ -15,6 +15,7 @@ import {
} from '@/api/knowledgeBase'
import RbModal from '@/components/RbModal'
import SliderInput from '@/components/SliderInput'
import { stringRegExp } from '@/utils/validator'
const { TextArea } = Input;
const { confirm } = Modal

@@ -519,12 +520,16 @@ const CreateModal = forwardRef<CreateModalRef, CreateModalRefProps>(({
          <Form.Item
            name="name"
            label={t('knowledgeBase.createForm.name')}
            rules={[{ required: true, message: t('knowledgeBase.createForm.nameRequired') }]}
            rules={[
              { required: true, message: t('knowledgeBase.createForm.nameRequired') },
              { max: 50 },
              { pattern: stringRegExp, message: t('common.nameInvalid') },
            ]}
          >
            <Input placeholder={t('knowledgeBase.createForm.name')} />
          </Form.Item>
        )}
        <Form.Item name="description" label={t('knowledgeBase.createForm.description')}>
        <Form.Item name="description" label={t('knowledgeBase.createForm.description')} rules={[{ max: 500 }]}>
          <TextArea rows={2} placeholder={t('knowledgeBase.createForm.description')} />
        </Form.Item>

@@ -1,8 +1,8 @@
/*
 * @Author: ZhaoYing
 * @Date: 2026-02-03 17:33:15
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-02-03 17:33:15
 * @Last Modified by: ZhaoYing
 * @Last Modified time: 2026-03-05 16:28:58
 */
/**
 * Memory Management Page
@@ -110,9 +110,15 @@ const MemoryManagement: React.FC = () => {
            <List.Item key={item.config_id}>
              <RbCard
                title={item.config_name}
                className="rb:relative"
              >
                {item.is_system_default &&
                  <div className="rb:absolute rb:-right-px rb:-top-px rb:bg-[#FF5D34] rb:rounded-[0px_7px_0px_8px] rb:text-[12px] rb:text-white rb:font-regular rb:leading-4 rb:py-0.5 rb:px-1">
                    {t('common.default')}
                  </div>
                }
                <Tooltip title={item.config_desc}>
                  <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-[17px]">{item.config_desc}</div>
                  <div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.25 rb:font-regular rb:-mt-1 rb:wrap-break-word rb:line-clamp-1 rb:h-4.25">{item.config_desc}</div>
                </Tooltip>
                <RbAlert className="rb:mt-3 ">
                  <div className={clsx("rb:flex rb:gap-5 rb:font-regular rb:text-[14px]")}>

Some files were not shown because too many files have changed in this diff