Merge branch 'release/v0.2.6' into feature/memory_zy

This commit is contained in:
yingzhao
2026-03-05 16:49:46 +08:00
committed by GitHub
164 changed files with 6833 additions and 3090 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@ search_results.json
api/migrations/versions
tmp
files
powers/
# Exclude dep files
huggingface.co/

View File

@@ -3,9 +3,8 @@ Cache 缓存模块
提供各种缓存功能的统一入口
"""
from .memory import EmotionMemoryCache, ImplicitMemoryCache
from .memory import InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

View File

@@ -3,10 +3,8 @@ Memory 缓存模块
提供记忆系统相关的缓存功能
"""
from .emotion_memory import EmotionMemoryCache
from .implicit_memory import ImplicitMemoryCache
from .interest_memory import InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

View File

@@ -1,134 +0,0 @@
"""
Emotion Suggestions Cache
情绪个性化建议缓存模块
用于缓存用户的情绪个性化建议数据
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
class EmotionMemoryCache:
"""情绪建议缓存类"""
# Key 前缀
PREFIX = "cache:memory:emotion_memory"
@classmethod
def _get_key(cls, *parts: str) -> str:
"""生成 Redis key
Args:
*parts: key 的各个部分
Returns:
完整的 Redis key
"""
return ":".join([cls.PREFIX] + list(parts))
@classmethod
async def set_emotion_suggestions(
cls,
user_id: str,
suggestions_data: Dict[str, Any],
expire: int = 86400
) -> bool:
"""设置用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
suggestions_data: 建议数据字典,包含:
- health_summary: 健康状态摘要
- suggestions: 建议列表
- generated_at: 生成时间(可选)
expire: 过期时间默认24小时86400秒
Returns:
是否设置成功
"""
try:
key = cls._get_key("suggestions", user_id)
# 添加生成时间戳
if "generated_at" not in suggestions_data:
suggestions_data["generated_at"] = datetime.now().isoformat()
# 添加缓存标记
suggestions_data["cached"] = True
value = json.dumps(suggestions_data, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
"""获取用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
Returns:
建议数据字典,如果不存在或已过期返回 None
"""
try:
key = cls._get_key("suggestions", user_id)
value = await aio_redis.get(key)
if value:
data = json.loads(value)
logger.info(f"成功获取情绪建议缓存: {key}")
return data
logger.info(f"情绪建议缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
"""删除用户情绪建议缓存
Args:
user_id: 用户IDend_user_id
Returns:
是否删除成功
"""
try:
key = cls._get_key("suggestions", user_id)
result = await aio_redis.delete(key)
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_suggestions_ttl(cls, user_id: str) -> int:
"""获取情绪建议缓存的剩余过期时间
Args:
user_id: 用户IDend_user_id
Returns:
剩余秒数,-1表示永不过期-2表示key不存在
"""
try:
key = cls._get_key("suggestions", user_id)
ttl = await aio_redis.ttl(key)
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}")
return ttl
except Exception as e:
logger.error(f"获取情绪建议缓存TTL失败: {e}")
return -2

View File

@@ -1,136 +0,0 @@
"""
Implicit Memory Profile Cache
隐式记忆用户画像缓存模块
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
class ImplicitMemoryCache:
"""隐式记忆用户画像缓存类"""
# Key 前缀
PREFIX = "cache:memory:implicit_memory"
@classmethod
def _get_key(cls, *parts: str) -> str:
"""生成 Redis key
Args:
*parts: key 的各个部分
Returns:
完整的 Redis key
"""
return ":".join([cls.PREFIX] + list(parts))
@classmethod
async def set_user_profile(
cls,
user_id: str,
profile_data: Dict[str, Any],
expire: int = 86400
) -> bool:
"""设置用户完整画像缓存
Args:
user_id: 用户IDend_user_id
profile_data: 画像数据字典,包含:
- preferences: 偏好标签列表
- portrait: 四维画像对象
- interest_areas: 兴趣领域分布对象
- habits: 行为习惯列表
- generated_at: 生成时间(可选)
expire: 过期时间默认24小时86400秒
Returns:
是否设置成功
"""
try:
key = cls._get_key("profile", user_id)
# 添加生成时间戳
if "generated_at" not in profile_data:
profile_data["generated_at"] = datetime.now().isoformat()
# 添加缓存标记
profile_data["cached"] = True
value = json.dumps(profile_data, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
"""获取用户完整画像缓存
Args:
user_id: 用户IDend_user_id
Returns:
画像数据字典,如果不存在或已过期返回 None
"""
try:
key = cls._get_key("profile", user_id)
value = await aio_redis.get(key)
if value:
data = json.loads(value)
logger.info(f"成功获取用户画像缓存: {key}")
return data
logger.info(f"用户画像缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_user_profile(cls, user_id: str) -> bool:
"""删除用户完整画像缓存
Args:
user_id: 用户IDend_user_id
Returns:
是否删除成功
"""
try:
key = cls._get_key("profile", user_id)
result = await aio_redis.delete(key)
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_profile_ttl(cls, user_id: str) -> int:
"""获取用户画像缓存的剩余过期时间
Args:
user_id: 用户IDend_user_id
Returns:
剩余秒数,-1表示永不过期-2表示key不存在
"""
try:
key = cls._get_key("profile", user_id)
ttl = await aio_redis.ttl(key)
logger.debug(f"用户画像缓存TTL: {key} = {ttl}")
return ttl
except Exception as e:
logger.error(f"获取用户画像缓存TTL失败: {e}")
return -2

122
api/app/cache/memory/interest_memory.py vendored Normal file
View File

@@ -0,0 +1,122 @@
"""
Interest Distribution Cache
兴趣分布缓存模块
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
"""
import json
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
INTEREST_CACHE_EXPIRE = 86400
class InterestMemoryCache:
"""兴趣分布缓存类"""
PREFIX = "cache:memory:interest_distribution"
@classmethod
def _get_key(cls, end_user_id: str, language: str) -> str:
"""生成 Redis key
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
@classmethod
async def set_interest_distribution(
cls,
end_user_id: str,
language: str,
data: List[Dict[str, Any]],
expire: int = INTEREST_CACHE_EXPIRE,
) -> bool:
"""设置用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(end_user_id, language)
payload = {
"data": data,
"generated_at": datetime.now().isoformat(),
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> Optional[List[Dict[str, Any]]]:
"""获取用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
兴趣分布列表,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}")
return payload.get("data")
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> bool:
"""删除用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
是否删除成功
"""
try:
key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,9 +1,11 @@
import os
import platform
from datetime import timedelta
from celery.schedules import crontab
from urllib.parse import quote
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
@@ -43,8 +45,8 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
@@ -82,7 +84,8 @@ celery_app.conf.update(
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
},
)
@@ -90,11 +93,14 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
# 这个30秒的设计不合理
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
implicit_emotions_update_schedule = crontab(
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
#构建定时任务配置
beat_schedule_config = {
@@ -115,16 +121,16 @@ beat_schedule_config = {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
"write-all-workspaces-memory": {
"task": "app.tasks.write_all_workspaces_memory_task",
"schedule": memory_increment_schedule,
"args": (),
},
"update-implicit-emotions-storage": {
"task": "app.tasks.update_implicit_emotions_storage",
"schedule": implicit_emotions_update_schedule,
"args": (),
},
}
#如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -396,10 +396,10 @@ async def draft_run(
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
service = AppService(db)
draft_service = DraftRunService(db)
draft_service = AgentRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
@@ -484,8 +484,8 @@ async def draft_run(
}
)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
@@ -789,8 +789,8 @@ async def draft_run_compare(
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
@@ -820,8 +820,8 @@ async def draft_run_compare(
)
# 非流式返回
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
@@ -835,7 +835,8 @@ async def draft_run_compare(
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60
timeout=payload.timeout or 60,
files=payload.files
)
logger.info(

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -208,14 +208,64 @@ async def get_emotion_health(
# @router.post("/check-data", response_model=ApiResponse)
# async def check_emotion_data_exists(
# request: EmotionSuggestionsRequest,
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user),
# ):
# """检查用户情绪建议数据是否存在
# Args:
# request: 包含 end_user_id
# db: 数据库会话
# current_user: 当前用户
# Returns:
# 数据存在状态
# """
# try:
# api_logger.info(
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
# extra={"end_user_id": request.end_user_id}
# )
# # 从数据库获取建议
# data = await emotion_service.get_cached_suggestions(
# end_user_id=request.end_user_id,
# db=db
# )
# if data is None:
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
# return fail(
# BizCode.NOT_FOUND,
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
# {"exists": False}
# )
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
# return success(data={"exists": True}, msg="情绪建议数据已存在")
# except Exception as e:
# api_logger.error(
# f"检查情绪建议数据失败: {str(e)}",
# extra={"end_user_id": request.end_user_id},
# exc_info=True
# )
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail=f"检查情绪建议数据失败: {str(e)}"
# )
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
request: EmotionSuggestionsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取个性化情绪建议(从缓存读取)
"""获取个性化情绪建议(从数据库读取)
Args:
request: 包含 end_user_id 和可选的 config_id
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
current_user: 当前用户
Returns:
存的个性化情绪建议响应
的个性化情绪建议响应
"""
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
f"用户 {current_user.username} 请求获取个性化情绪建议",
extra={
"end_user_id": request.end_user_id,
"config_id": request.config_id
}
)
# 从缓存获取建议
# 从数据库获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.end_user_id,
db=db
)
if data is None:
# 缓存不存在或已过期,自动触发生成
api_logger.info(
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
f"用户 {request.end_user_id} 的建议数据不存在",
extra={"end_user_id": request.end_user_id}
)
try:
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.end_user_id,
db=db,
language=language
)
# 保存到缓存
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
db=db,
expires_hours=24
)
except (ValueError, KeyError) as gen_e:
# 预期内的业务异常:配置缺失、数据格式问题等
api_logger.warning(
f"自动生成建议失败(业务异常): {str(gen_e)}",
extra={"end_user_id": request.end_user_id}
)
return fail(
BizCode.NOT_FOUND,
f"自动生成建议失败: {str(gen_e)}",
""
)
except Exception as gen_e:
# 非预期异常:记录完整 traceback 便于排查
api_logger.error(
f"自动生成建议时发生未预期异常: {str(gen_e)}",
extra={"end_user_id": request.end_user_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成建议时发生内部错误: {str(gen_e)}"
)
return success(
data={"exists": False},
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(
"个性化建议获取成功(缓存)",
"个性化建议获取成功",
extra={
"end_user_id": request.end_user_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议获取成功(缓存)")
return success(data=data, msg="个性化建议获取成功")
except Exception as e:
api_logger.error(
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""生成个性化情绪建议调用LLM并缓存
"""生成个性化情绪建议调用LLM并保存到数据库
Args:
request: 包含 end_user_id
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
language=language
)
# 保存到缓存
# 保存到数据库
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
db=db,
expires_hours=24
db=db
)
api_logger.info(
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成个性化建议失败: {str(e)}"
)
)

View File

@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def check_user_data_exists(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
检查用户画像数据是否存在
Args:
end_user_id: 目标用户ID
Returns:
数据存在状态
"""
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
try:
# Validate inputs
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return success(
data={"exists": False},
msg="画像数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
return success(data={"exists": True}, msg="画像数据已存在")
except Exception as e:
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
@@ -159,12 +201,8 @@ async def get_preference_tags(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract preferences from cache
preferences = cached_profile.get("preferences", [])
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
@@ -330,12 +360,8 @@ async def get_behavior_habits(
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.NOT_FOUND,
"画像缓存不存在或已过期,请右上角刷新生成新画像",
""
)
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
# Extract habits from cache
habits = cached_profile.get("habits", [])

View File

@@ -90,7 +90,7 @@ async def get_mcp_servers(
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
api_logger.error(f"Failed to get MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get MCP servers: {str(e)}"
@@ -118,6 +118,65 @@ async def get_mcp_servers(
return success(data=result, msg="Query of mcp servers list successful")
@router.get("/operational_mcp_servers", response_model=ApiResponse)
async def get_operational_mcp_servers(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the operational mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + operational mcp server list
"""
api_logger.info(
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
api.login(token)
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
r = api.session.get(url, headers=headers, cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get operational MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 3. Return structured response
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
@router.get("/mcp_server", response_model=ApiResponse)
async def get_mcp_server(
mcp_market_config_id: uuid.UUID,

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
@@ -633,12 +634,11 @@ async def get_knowledge_type_stats_api(
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | Memory
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- Memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
- 如果用户没有当前工作空间,对应的统计返回 0
"""
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
@@ -662,34 +662,56 @@ async def get_knowledge_type_stats_api(
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session=Depends(get_db),
db: Session = Depends(get_db),
):
"""
获取指定用户的热门记忆标签
获取指定用户的兴趣分布标签
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
返回格式:
[
{"name": "标签", "frequency": 频次},
{"name": "兴趣活动", "frequency": 频次},
...
]
"""
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
language = get_language_from_header(language_type)
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
# 优先读取缓存
cached = await InterestMemoryCache.get_interest_distribution(
end_user_id=end_user_id,
limit=limit
language=language,
)
return success(data=result, msg="获取热门记忆标签成功")
if cached is not None:
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
return success(data=cached, msg="获取兴趣分布标签成功")
# 缓存未命中,调用模型生成
result = await memory_agent_service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=limit,
language=language
)
# 写入缓存24小时过期
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
)
return success(data=result, msg="获取兴趣分布标签成功")
except Exception as e:
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
api_logger.error(f"Interest distribution by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)

View File

@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -469,6 +470,8 @@ async def get_chunk_insight(
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -503,6 +506,15 @@ async def dashboard_data(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 如果没有提供时间范围默认使用最近30天
if start_date is None or end_date is None:
from datetime import datetime, timedelta
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=30)
end_date = int(end_dt.timestamp() * 1000)
start_date = int(start_dt.timestamp() * 1000)
api_logger.info(f"使用默认时间范围: {start_dt}{end_dt}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -563,17 +575,22 @@ async def dashboard_data(
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用增量total_api_call,转换为整数
# 3. 获取API调用统计total_api_call
try:
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
# 使用 AppStatisticsService 获取真实的API调用统计
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
current_user=current_user
start_date=start_date,
end_date=end_date
)
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取API调用增量失败: {str(e)}")
api_logger.error(f"获取API调用统计失败: {str(e)}")
neo4j_data["total_api_call"] = 0
result["neo4j_data"] = neo4j_data
api_logger.info("成功获取neo4j_data")
@@ -589,8 +606,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
@@ -602,10 +619,23 @@ async def dashboard_data(
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -1,16 +1,18 @@
from fastapi import APIRouter, Depends, HTTPException, status,Header
from typing import Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy.orm import Session
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
@@ -29,11 +31,11 @@ async def short_term_configs(
language = get_language_from_header(language_type)
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_term=ShortService(end_user_id, db)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_term=LongService(end_user_id, db)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)

View File

@@ -469,7 +469,9 @@ async def create_model_api_key_by_provider(
config=api_key_data.config,
is_active=api_key_data.is_active,
priority=api_key_data.priority,
model_config_ids=model_config_ids
model_config_ids=model_config_ids,
capability=api_key_data.capability,
is_omni=api_key_data.is_omni
)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)

View File

@@ -124,15 +124,23 @@ def _get_ontology_service(
)
# 通过 Repository 获取可用的 API Key负载均衡逻辑由 Repository 处理)
from app.repositories.model_repository import ModelApiKeyRepository
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
if not api_keys:
# from app.repositories.model_repository import ModelApiKeyRepository
from app.services.model_service import ModelApiKeyService
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
if not api_key_config:
logger.error(f"Model {llm_id} has no active API key")
raise HTTPException(
status_code=400,
detail="指定的LLM模型没有可用的API密钥"
)
api_key_config = api_keys[0]
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
# if not api_keys:
# logger.error(f"Model {llm_id} has no active API key")
# raise HTTPException(
# status_code=400,
# detail="指定的LLM模型没有可用的API密钥"
# )
# api_key_config = api_keys[0]
is_composite = getattr(model_config, 'is_composite', False)
logger.info(
@@ -154,6 +162,7 @@ def _get_ontology_service(
provider=actual_provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
max_retries=3,
timeout=60.0
)
@@ -514,10 +523,9 @@ async def delete_scene(
f"尝试删除系统默认场景: user_id={current_user.id}, "
f"scene_id={scene_id}, scene_name={scene.scene_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认场景不可删除",
"该场景为系统预设场景,不允许删除"
raise HTTPException(
status_code=400,
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
)
# 创建OntologyService实例
@@ -543,6 +551,9 @@ async def delete_scene(
return success(data={"deleted": success_flag}, msg="场景删除成功")
except HTTPException:
raise
except ValueError as e:
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))

View File

@@ -2,25 +2,32 @@ import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.core.response_utils import success, fail
from app.db import get_db, get_db_read
from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
@@ -206,15 +213,13 @@ def list_conversations(
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
share, release = service.get_release_by_share_token(share_data.share_token, password)
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
@@ -293,19 +298,15 @@ async def chat(
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service._get_release_by_share_token(share_token, password)
share, release = service.get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
@@ -318,7 +319,6 @@ async def chat(
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app仅查询未删除的应用
from app.models.app_model import App
app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
@@ -359,12 +359,12 @@ async def chat(
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == "agent":
if app_type == AppType.AGENT:
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == "multi_agent":
elif app_type == AppType.MULTI_AGENT:
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
@@ -638,6 +638,34 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@router.get("/config", summary="获取应用启动配置")
async def config_query(
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
share_service = SharedChatService(db)
share_token = share_data.share_token
share, release = share_service.get_release_by_share_token(share_token, password)
if release.app.type == AppType.WORKFLOW:
workflow_service = WorkflowService(db)
content = {
"app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config)
}
elif release.app.type == AppType.AGENT:
content = {
"app_type": release.app.type,
"variables": release.config.get("variables")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {
"app_type": release.app.type,
"variables": []
}
else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
return success(data=content)

View File

@@ -249,6 +249,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -11,35 +11,37 @@ LangChain Agent 封装
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
):
"""初始化 LangChain Agent
@@ -60,12 +62,13 @@ class LangChainAgent:
self.provider = provider
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None
# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
@@ -73,9 +76,9 @@ class LangChainAgent:
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
self.system_prompt = system_prompt or "你是一个专业的AI助手"
logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
@@ -89,6 +92,7 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -143,21 +147,22 @@ class LangChainAgent:
"""
from langchain_core.tools import StructuredTool
from functools import wraps
wrapped_tools = []
for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None
if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue
# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""
@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
@@ -168,13 +173,13 @@ class LangChainAgent:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name
current_count = self.tool_call_counter[tool_name]
logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)
# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
@@ -185,12 +190,12 @@ class LangChainAgent:
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)
# 调用原始工具函数
return original_func(*args, **kwargs)
return wrapped_func
# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
@@ -198,17 +203,17 @@ class LangChainAgent:
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
) -> List[BaseMessage]:
"""准备消息列表
@@ -248,7 +253,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
@@ -261,23 +266,26 @@ class LangChainAgent:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
if self.provider.lower() in ["bedrock", "anthropic"]:
# Anthropic/Bedrock: {"type": "text", "text": "..."}
content_parts = [{"type": "text", "text": text}]
else:
# 通义千问等: {"text": "..."}
content_parts = [{"text": text}]
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]
# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)
logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)
return content_parts
async def chat(
@@ -302,7 +310,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat= message
message_chat = message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -322,8 +330,8 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
@@ -367,14 +375,14 @@ class LangChainAgent:
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")
# 处理多模态响应content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
@@ -407,12 +415,13 @@ class LangChainAgent:
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = {
"content": content,
"model": self.model_name,
@@ -439,16 +448,16 @@ class LangChainAgent:
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str, None]:
"""执行流式对话
@@ -482,7 +491,6 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表(支持多模态)
@@ -500,13 +508,13 @@ class LangChainAgent:
full_content = ''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
@@ -540,7 +548,7 @@ class LangChainAgent:
full_content += item
yield item
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
@@ -577,13 +585,13 @@ class LangChainAgent:
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
@@ -595,7 +603,8 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -609,5 +618,3 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -1,9 +1,10 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Annotated, Any, Dict, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
load_dotenv()
@@ -200,14 +201,30 @@ class Settings:
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Celery Beat Schedule Configuration (定时任务执行频率)
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行分钟0-59
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
@@ -231,7 +248,7 @@ class Settings:
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"

View File

@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
config_id: Configuration ID for processing (used to load pruning config)
Returns:
List of DialogData objects with generated chunks
@@ -57,6 +57,61 @@ async def get_chunked_dialogs(
end_user_id=end_user_id,
config_id=config_id
)
# 语义剪枝步骤(在分块之前)
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 加载剪枝配置
pruning_config = None
if config_id:
try:
with get_db_context() as db:
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="semantic_pruning"
)
if memory_config:
pruning_config = PruningConfig(
pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold
)
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
# 获取LLM客户端用于剪枝
if pruning_config.pruning_switch:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
original_msg_count = len(dialog_data.context.msgs)
# 使用 prune_dataset 而不是 prune_dialog
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
pruned_dialogs = await pruner.prune_dataset([dialog_data])
if pruned_dialogs:
dialog_data = pruned_dialogs[0]
remaining_msg_count = len(dialog_data.context.msgs)
deleted_count = original_msg_count - remaining_msg_count
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
else:
logger.warning("[剪枝] prune_dataset 返回空列表")
else:
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
except Exception as e:
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
except Exception as e:
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)

View File

@@ -1,9 +1,12 @@
import asyncio
import json
import logging
import os
from typing import List, Tuple
from app.core.config import settings
logger = logging.getLogger(__name__)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
class InterestTags(BaseModel):
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
return structured_response.meaningful_tags
except Exception as e:
print(f"LLM筛选过程中发生错误: {e}")
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
# 在LLM失败时返回原始标签确保流程继续
return tags
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
"""
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
Args:
tags: 原始标签列表
end_user_id: 用户ID用于获取LLM配置
Returns:
筛选后的兴趣活动标签列表
"""
try:
with get_db_context() as db:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
if not config_id and not workspace_id:
raise ValueError(
f"No memory_config_id found for end_user_id: {end_user_id}."
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
if not memory_config.llm_model_id:
raise ValueError(
f"No llm_model_id found in memory config {config_id}."
)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
tag_list_str = ", ".join(tags)
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
messages = [
{
"role": "user",
"content": rendered_prompt
}
]
structured_response = await llm_client.response_structured(
messages=messages,
response_model=InterestTags
)
return structured_response.interest_tags
except Exception as e:
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
return tags
async def get_raw_tags_from_db(
connector: Neo4jConnector,
end_user_id: str,
@@ -139,14 +210,14 @@ async def get_raw_tags_from_db(
return [(record["name"], record["frequency"]) for record in results]
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
"""
获取原始标签然后使用LLM进行筛选返回最终的热门标签列表。
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
查询更多的标签(40)给LLM提供更丰富的上下文进行筛选但最终返回数量由limit参数控制
Args:
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 返回的标签数量限制
limit: 最终返回的标签数量限制默认10
by_user: 是否按user_id查询默认False按end_user_id查询
Raises:
@@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
# 使用项目的Neo4jConnector
connector = Neo4jConnector()
try:
# 1. 从数据库获取原始排名靠前的标签
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
# 1. 从数据库获取原始排名靠前的标签查询40条给LLM提供更丰富的上下文
query_limit = 40
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
if not raw_tags_with_freq:
return []
@@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
if tag in meaningful_tag_names:
final_tags.append((tag, freq))
return final_tags
# 4. 限制返回的标签数量
return final_tags[:limit]
finally:
# 确保关闭连接
await connector.close()
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
"""
获取用户的兴趣分布标签。
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
Args:
end_user_id: 必需参数。如果by_user=False则为end_user_id如果by_user=True则为user_id
limit: 最终返回的标签数量限制默认10
by_user: 是否按user_id查询默认False按end_user_id查询
Raises:
ValueError: 如果end_user_id未提供或为空
"""
if not end_user_id or not end_user_id.strip():
raise ValueError(
"end_user_id is required. Please provide a valid end_user_id or user_id."
)
connector = Neo4jConnector()
try:
# 查询更多原始标签给LLM提供充足上下文
query_limit = 40
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
if not raw_tags_with_freq:
return []
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
# 使用兴趣活动专用prompt进行筛选支持语义推断出新标签
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
# 构建最终标签列表:
# - 原始标签中存在的,保留原始频率
# - LLM推断出的新标签不在原始列表中赋予默认频率1
final_tags = []
seen = set()
for tag in interest_tag_names:
if tag in seen:
continue
seen.add(tag)
freq = raw_freq_map.get(tag, 1)
final_tags.append((tag, freq))
# 按频率降序排列
final_tags.sort(key=lambda x: x[1], reverse=True)
return final_tags[:limit]
finally:
await connector.close()

View File

@@ -5,20 +5,27 @@
- 对话级一次性抽取判定相关性
- 仅对"不相关对话"的消息按比例删除
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
"""
import asyncio
import os
import hashlib
import json
import re
from collections import OrderedDict
from datetime import datetime
from typing import List, Optional
from typing import List, Optional, Dict, Tuple, Set
from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
SceneConfigRegistry,
ScenePatterns
)
class DialogExtractionResponse(BaseModel):
@@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel):
keywords: List[str] = Field(default_factory=list)
class MessageImportanceResponse(BaseModel):
"""消息重要性批量判断的结构化返回用于LLM语义判断
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
- reasons: 可选的判断理由
"""
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
class QAPair(BaseModel):
"""问答对模型,用于识别和保护对话中的问答结构。"""
question_idx: int = Field(..., description="问题消息的索引")
answer_idx: int = Field(..., description="答案消息的索引")
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
class SemanticPruner:
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
@@ -43,109 +67,374 @@ class SemanticPruner:
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
"""
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
cfg_dict = get_pruning_config() if config is None else config.model_dump()
self.config = PruningConfig.model_validate(cfg_dict)
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
# 如果没有提供config使用默认配置
if config is None:
# 使用默认的剪枝配置
config = PruningConfig(
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
pruning_scene="education",
pruning_threshold=0.5
)
self.config = config
self.llm_client = llm_client
self.language = language # 保存语言配置
self.max_concurrent = max_concurrent # 新增:最大并发数
# 详细日志配置:限制逐条消息日志的数量
self._detailed_prune_logging = True # 是否启用详细日志
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
# 加载场景特定配置
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
self.config.pruning_scene,
fallback_to_generic=True
)
# 检查场景是否有专门支持
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
if is_supported:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
else:
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
# Load Jinja2 template
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
self._cache_max_size = 1000 # 缓存大小限制
# 运行日志:收集关键终端输出,便于写入 JSON
self.run_logs: List[str] = []
# 采用顺序处理,移除并发配置以简化与稳定执行
def _is_important_message(self, message: ConversationMessage) -> bool:
"""基于启发式规则识别重要信息消息,优先保留。
- 含日期/时间如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
- 关键词:"时间""日期""编号""订单""流水""金额""""""电话""手机号""邮箱""地址"
改进版:使用场景特定的模式进行识别
- 根据 pruning_scene 动态加载对应的识别规则
- 支持教育、在线服务、外呼三个场景的特定模式
"""
import re
text = message.msg.strip()
if not text:
return False
patterns = [
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
r"\b\d{1,2}:\d{2}\b",
r"\d{4}\d{1,2}月\d{1,2}日",
r"上午|下午|AM|PM",
r"订单号|工单|申请号|编号|ID|账号|账户",
r"电话|手机号|微信|QQ|邮箱",
r"地址|地点",
r"金额|费用|价格|¥|¥|\d+元",
r"时间|日期|有效期|截止",
]
for p in patterns:
if re.search(p, text, flags=re.IGNORECASE):
# 使用场景特定的模式
all_patterns = (
self.scene_config.high_priority_patterns +
self.scene_config.medium_priority_patterns +
self.scene_config.low_priority_patterns
)
for pattern, _ in all_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
return True
# 检查是否为问句(以问号结尾或包含疑问词)
if text.endswith("") or text.endswith("?"):
return True
# 检查是否包含问句关键词
if any(keyword in text for keyword in self.scene_config.question_keywords):
return True
# 检查是否包含决策性关键词
if any(keyword in text for keyword in self.scene_config.decision_keywords):
return True
return False
def _importance_score(self, message: ConversationMessage) -> int:
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
简单启发:匹配到的类别越多、越关键分值越高。
改进版使用场景特定的权重体系0-10分
- 根据场景动态调整不同信息类型的权重
- 高优先级模式4-6分
- 中优先级模式2-3分
- 低优先级模式1分
"""
import re
text = message.msg.strip()
score = 0
weights = [
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
(r"\b\d{1,2}:\d{2}\b", 2),
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
(r"电话|手机号|微信|QQ|邮箱", 3),
(r"地址|地点", 2),
(r"金额|费用|价格|¥|¥|\d+元", 4),
(r"时间|日期|有效期|截止", 2),
]
for p, w in weights:
if re.search(p, text, flags=re.IGNORECASE):
score += w
return score
# 使用场景特定的权重
for pattern, weight in self.scene_config.high_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.medium_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
for pattern, weight in self.scene_config.low_priority_patterns:
if re.search(pattern, text, flags=re.IGNORECASE):
score += weight
# 问句加分
if text.endswith("") or text.endswith("?"):
score += 2
# 包含问句关键词加分
if any(keyword in text for keyword in self.scene_config.question_keywords):
score += 1
# 包含决策性关键词加分
if any(keyword in text for keyword in self.scene_config.decision_keywords):
score += 2
# 长度加分(较长的消息通常包含更多信息)
if len(text) > 50:
score += 1
if len(text) > 100:
score += 1
return min(score, 10) # 最高10分
def _is_filler_message(self, message: ConversationMessage) -> bool:
"""检测典型寒暄/口头禅/确认类短消息用于跳过LLM分类以加速
"""检测典型寒暄/口头禅/确认类短消息。
改进版:更严格的填充消息判断,避免误删场景相关内容
满足以下之一视为填充消息:
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
- 纯标点或空白
- 在场景特定填充词库中(精确匹配)
- 纯表情符号
- 常见寒暄(精确匹配短语)
注意:不再使用长度判断,避免误删短但重要的消息
"""
import re
t = message.msg.strip()
if not t:
return True
# 常见填充语
fillers = [
"你好", "您好", "在吗", "", "嗯嗯", "", "好的", "", "", "可以", "不可以", "谢谢",
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", ""
]
if t in fillers:
# 检查是否在场景特定填充词库中(精确匹配)
if t in self.scene_config.filler_phrases:
return True
# 长度与字符类型判断
if len(t) <= 8:
# 非数字、无关键实体的短文本
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
# 主要是标点或简单确认词
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
return True
# 常见寒暄和问候(精确匹配,避免误删)
common_greetings = {
"在吗", "在不在", "在呢", "在的",
"你好", "您好", "hello", "hi",
"拜拜", "再见", "", "88", "bye",
"好的", "", "", "可以", "", "", "",
"是的", "", "对的", "没错", "是啊",
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
}
if t in common_greetings:
return True
# 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t):
return True
# 检查是否为纯emojiUnicode表情
emoji_pattern = re.compile(
"["
"\U0001F600-\U0001F64F" # 表情符号
"\U0001F300-\U0001F5FF" # 符号和象形文字
"\U0001F680-\U0001F6FF" # 交通和地图符号
"\U0001F1E0-\U0001F1FF" # 旗帜
"\U00002702-\U000027B0"
"\U000024C2-\U0001F251"
"]+", flags=re.UNICODE
)
if emoji_pattern.fullmatch(t):
return True
# 纯标点符号
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
return True
return False
async def _batch_evaluate_importance_with_llm(
self,
messages: List[ConversationMessage],
context: str = ""
) -> Dict[int, int]:
"""使用LLM批量评估消息的重要性语义层面
Args:
messages: 消息列表
context: 对话上下文(可选)
Returns:
消息索引到重要性分数(0-10)的映射
"""
if not self.llm_client or not messages:
return {}
# 构建批量评估的提示词
msg_list = []
for idx, msg in enumerate(messages):
msg_list.append(f"{idx}. {msg.msg}")
msg_text = "\n".join(msg_list)
prompt = f"""请评估以下消息的重要性给每条消息打分0-10分
- 0-2分无意义的寒暄、口头禅、纯表情
- 3-5分一般性对话有一定信息量但不关键
- 6-8分包含重要信息时间、地点、人物、事件等
- 9-10分关键决策、承诺、重要数据
对话上下文:
{context if context else ""}
待评估的消息:
{msg_text}
请以JSON格式返回格式为
{{
"importance_scores": {{
"0": 分数,
"1": 分数,
...
}}
}}
"""
try:
messages_for_llm = [
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
{"role": "user", "content": prompt}
]
response = await self.llm_client.response_structured(
messages_for_llm,
MessageImportanceResponse
)
# 转换字符串键为整数键
return {int(k): v for k, v in response.importance_scores.items()}
except Exception as e:
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
return {}
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
"""识别对话中的问答对,用于保护问答结构的完整性。
改进版:使用场景特定的问句关键词,并排除寒暄类问句
Args:
messages: 消息列表
Returns:
问答对列表
"""
qa_pairs = []
# 寒暄类问句,不应该被保护(这些不是真正的问答)
greeting_questions = {
"在吗", "在不在", "你好吗", "怎么样", "好吗",
"有空吗", "忙吗", "睡了吗", "起床了吗"
}
for i in range(len(messages) - 1):
current_msg = messages[i].msg.strip()
next_msg = messages[i + 1].msg.strip()
# 排除寒暄类问句
if current_msg in greeting_questions:
continue
# 使用场景特定的问句关键词,但要求更严格
is_question = False
# 1. 以问号结尾
if current_msg.endswith("") or current_msg.endswith("?"):
is_question = True
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时"]):
is_question = True
if is_question and next_msg:
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
is_answer = not (next_msg.endswith("") or next_msg.endswith("?"))
# 排除寒暄类回复
greeting_answers = {"你好", "您好", "在呢", "在的", "", "", "好的"}
if next_msg in greeting_answers:
is_answer = False
if is_answer:
qa_pairs.append(QAPair(
question_idx=i,
answer_idx=i + 1,
confidence=0.8 # 基于规则的置信度
))
return qa_pairs
def _get_protected_indices(
self,
messages: List[ConversationMessage],
qa_pairs: List[QAPair],
window_size: int = 2
) -> Set[int]:
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
Args:
messages: 消息列表
qa_pairs: 问答对列表
window_size: 上下文窗口大小(前后各保留几条消息)
Returns:
需要保护的消息索引集合
"""
protected = set()
for qa_pair in qa_pairs:
# 保护问答对本身
protected.add(qa_pair.question_idx)
protected.add(qa_pair.answer_idx)
# 保护上下文窗口
for offset in range(-window_size, window_size + 1):
q_idx = qa_pair.question_idx + offset
a_idx = qa_pair.answer_idx + offset
if 0 <= q_idx < len(messages):
protected.add(q_idx)
if 0 <= a_idx < len(messages):
protected.add(a_idx)
return protected
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
- 仅使用 LLM 结构化输出;
改进版:
- LRU缓存管理
- 重试机制
- 降级策略
"""
# 缓存命中则直接返回(场景+内容作为键)
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
# LRU缓存如果命中移到末尾最近使用
if cache_key in self._dialog_extract_cache:
self._dialog_extract_cache.move_to_end(cache_key)
return self._dialog_extract_cache[cache_key]
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
# LRU缓存大小限制超过限制时删除最旧的条目
if len(self._dialog_extract_cache) >= self._cache_max_size:
# 删除最旧的条目OrderedDict的第一个
oldest_key = next(iter(self._dialog_extract_cache))
del self._dialog_extract_cache[oldest_key]
self._log(f"[剪枝-缓存] LRU缓存已满删除最旧条目")
rendered = self.template.render(
pruning_scene=self.config.pruning_scene,
dialog_text=dialog_text,
language=self.language
)
log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene,
"language": self.language
})
log_prompt_rendering("pruning-extract", rendered)
# 强制使用 LLM;移除正则回退
# 强制使用 LLM
if not self.llm_client:
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
@@ -153,12 +442,32 @@ class SemanticPruner:
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
{"role": "user", "content": rendered},
]
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
# 重试机制
max_retries = 3
for attempt in range(max_retries):
try:
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
self._dialog_extract_cache[cache_key] = ex
return ex
except Exception as e:
if attempt < max_retries - 1:
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
continue
else:
# 降级策略:标记为相关,避免误删
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
fallback_response = DialogExtractionResponse(
is_related=True,
times=[],
ids=[],
amounts=[],
contacts=[],
addresses=[],
keywords=[]
)
return fallback_response
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。"""
@@ -248,12 +557,14 @@ class SemanticPruner:
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
"""数据集层面:全局消息级剪枝,保留所有对话。
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
- 保证每段对话至少保留1条消息不会删除整段对话。
改进版:
- 消息级独立判断,每条消息根据场景规则独立评估
- 问答对保护已注释(暂不启用,留作观察)
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
- 保证每段对话至少保留1条消息不会删除整段对话
"""
# 如果剪枝功能关闭,直接返回原始数据集
# 如果剪枝功能关闭,直接返回原始数据集
if not self.config.pruning_switch:
return dialogs
@@ -264,179 +575,140 @@ class SemanticPruner:
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
)
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
evaluated_dialogs = []
for idx, dd in enumerate(dialogs):
try:
ex = await self._extract_dialog_important(dd.content)
evaluated_dialogs.append({
"dialog": dd,
"is_related": bool(ex.is_related),
"index": idx,
"extraction": ex
})
except Exception:
evaluated_dialogs.append({
"dialog": dd,
"is_related": True,
"index": idx,
"extraction": None
})
# 统计相关 / 不相关对话
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
self._log(
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
)
# 简洁打印第几段对话相关/不相关索引基于1
def _fmt_indices(items, cap: int = 10):
inds = [i["index"] + 1 for i in items]
if len(inds) <= cap:
return inds
# 超过上限时只打印前cap个并标注总数
return inds[:cap] + ["...", f"{len(inds)}"]
rel_inds = _fmt_indices(related_dialogs)
nrel_inds = _fmt_indices(not_related_dialogs)
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}")
result: List[DialogData] = []
if not_related_dialogs:
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM
per_dialog_info = {}
total_unrelated = 0
total_capacity = 0
for d in not_related_dialogs:
dd = d["dialog"]
extraction = d.get("extraction")
if extraction is None:
extraction = await self._extract_dialog_important(dd.content)
# 合并所有重要标记
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
msgs = dd.context.msgs
# 分类消息
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
# 重要消息按重要性排序
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
info = {
"dialog": dd,
"total_msgs": len(msgs),
"unrelated_count": len(msgs),
"imp_ids_sorted": imp_sorted_ids,
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
}
per_dialog_info[d["index"]] = info
total_unrelated += info["unrelated_count"]
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
global_delete = int(total_unrelated * proportion)
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
global_delete = 1
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例)且至少保留1条消息
capacities = []
for d in not_related_dialogs:
idx = d["index"]
info = per_dialog_info[idx]
# 统计重要数量
imp_count = len(info["imp_ids_sorted"])
unimp_count = len(info["unimp_ids"])
imp_cap = int(imp_count * proportion)
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
capacities.append(cap)
total_capacity = sum(capacities)
if global_delete > total_capacity:
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
global_delete = total_capacity
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
alloc = []
for i, d in enumerate(not_related_dialogs):
idx = d["index"]
info = per_dialog_info[idx]
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
alloc.append(min(share, capacities[i]))
allocated = sum(alloc)
rem = global_delete - allocated
turn = 0
while rem > 0 and turn < 100000:
progressed = False
for i in range(len(not_related_dialogs)):
if rem <= 0:
break
if alloc[i] < capacities[i]:
alloc[i] += 1
rem -= 1
progressed = True
if not progressed:
break
turn += 1
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
total_deleted_confirm = 0
for d in evaluated_dialogs:
dd = d["dialog"]
msgs = dd.context.msgs
original = len(msgs)
if d["is_related"]:
result.append(dd)
continue
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
if idx_in_unrel is None:
result.append(dd)
continue
quota = alloc[idx_in_unrel]
info = per_dialog_info[d["index"]]
# 计算本对话重要最多可删数量
imp_count = len(info["imp_ids_sorted"])
imp_del_cap = int(imp_count * proportion)
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
del_unimp = min(quota, len(unimp_delete_ids))
rem_quota = quota - del_unimp
# 再从重要里选低分优先的删除ID不超过 imp_del_cap
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
deleted_here = 0
actual_unimp_deleted = 0
actual_imp_deleted = 0
kept = []
for m in msgs:
mid = id(m)
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
actual_unimp_deleted += 1
deleted_here += 1
continue
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
actual_imp_deleted += 1
deleted_here += 1
continue
kept.append(m)
if not kept and msgs:
kept = [msgs[0]]
dd.context.msgs = kept
total_deleted_confirm += deleted_here
self._log(
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
else:
# 全部相关:不执行剪枝
result = [d["dialog"] for d in evaluated_dialogs]
total_original_msgs = 0
total_deleted_msgs = 0
for d_idx, dd in enumerate(dialogs):
msgs = dd.context.msgs
original_count = len(msgs)
total_original_msgs += original_count
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
# qa_pairs = self._identify_qa_pairs(msgs)
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
# ========================================================
# 消息级分类:每条消息独立判断
important_msgs = [] # 重要消息(保留)
unimportant_msgs = [] # 不重要消息(可删除)
filler_msgs = [] # 填充消息(优先删除)
# 判断是否需要详细日志仅对前N条消息记录
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
for idx, m in enumerate(msgs):
msg_text = m.msg.strip()
# ========== 问答对保护判断(已注释) ==========
# if idx in protected_indices:
# important_msgs.append((idx, m))
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
# ==========================================
# 填充消息(寒暄、表情等)
if self._is_filler_message(m):
filler_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
# 重要信息(学号、成绩、时间、金额等)
elif self._is_important_message(m):
important_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
# 其他消息
else:
unimportant_msgs.append((idx, m))
if should_log_details or idx < self._max_debug_msgs_per_dialog:
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
# 计算删除配额
delete_target = int(original_count * proportion)
if proportion > 0 and original_count > 0 and delete_target == 0:
delete_target = 1
# 确保至少保留1条消息
max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable)
# 删除策略:优先删除填充消息,再删除不重要消息
to_delete_indices = set()
deleted_details = [] # 记录删除的消息详情
# 第一步:删除填充消息
filler_to_delete = min(len(filler_msgs), delete_target)
for i in range(filler_to_delete):
idx, msg = filler_msgs[i]
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
# 第二步:如果还需要删除,删除不重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0:
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
for i in range(unimp_to_delete):
idx, msg = unimportant_msgs[i]
to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
# 第三步:如果还需要删除,按重要性分数删除重要消息
remaining_quota = delete_target - len(to_delete_indices)
if remaining_quota > 0 and important_msgs:
# 按重要性分数排序(分数低的优先删除)
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
imp_to_delete = min(len(imp_sorted), remaining_quota)
for i in range(imp_to_delete):
idx, msg = imp_sorted[i]
to_delete_indices.add(idx)
score = self._importance_score(msg)
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
# 执行删除
kept_msgs = []
for idx, m in enumerate(msgs):
if idx not in to_delete_indices:
kept_msgs.append(m)
# 确保至少保留1条
if not kept_msgs and msgs:
kept_msgs = [msgs[0]]
dd.context.msgs = kept_msgs
deleted_count = original_count - len(kept_msgs)
total_deleted_msgs += deleted_count
# 输出删除详情
if deleted_details:
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
for detail in deleted_details:
self._log(f" {detail}")
# ========== 问答对统计(已注释) ==========
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
# ========================================
self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}"
)
result.append(dd)
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
# 保存日志
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f:
@@ -448,6 +720,7 @@ class SemanticPruner:
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
def _log(self, msg: str) -> None:

View File

@@ -0,0 +1,326 @@
"""
场景特定配置 - 为不同场景提供定制化的剪枝规则
功能:
- 场景特定的重要信息识别模式
- 场景特定的重要性评分权重
- 场景特定的填充词库
- 场景特定的问答对识别规则
"""
from typing import Dict, List, Set, Tuple
from dataclasses import dataclass, field
@dataclass
class ScenePatterns:
"""场景特定的识别模式"""
# 重要信息的正则模式(优先级从高到低)
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
# 填充词库(无意义对话)
filler_phrases: Set[str] = field(default_factory=set)
# 问句关键词(用于识别问答对)
question_keywords: Set[str] = field(default_factory=set)
# 决策性/承诺性关键词
decision_keywords: Set[str] = field(default_factory=set)
class SceneConfigRegistry:
"""场景配置注册表 - 管理所有场景的特定配置"""
# 基础通用模式(所有场景共享)
BASE_HIGH_PRIORITY = [
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
(r"金额|费用|价格|¥|¥|\d+元", 5),
(r"\d{11}", 4), # 手机号
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
]
BASE_MEDIUM_PRIORITY = [
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
(r"\d{4}\d{1,2}月\d{1,2}日", 3),
(r"电话|手机号|微信|QQ|联系方式", 3),
(r"地址|地点|位置", 2),
(r"时间|日期|有效期|截止", 2),
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
(r"今年|去年|明年", 3),
]
BASE_LOW_PRIORITY = [
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
(r"AM|PM|am|pm", 1),
]
BASE_FILLERS = {
# 基础寒暄
"你好", "您好", "在吗", "在的", "在呢", "", "嗯嗯", "", "哦哦",
"好的", "", "", "可以", "不可以", "谢谢", "多谢", "感谢",
"拜拜", "再见", "88", "", "回见",
# 口头禅
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
"", "", "", "", "", "", "嗯哼",
# 确认词
"是的", "", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
# 标点和符号
"。。。", "...", "???", "", "!!!", "",
# 表情符号
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
# 网络用语
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
"emmm", "emm", "em", "mmp", "wtf", "omg",
}
BASE_QUESTION_KEYWORDS = {
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "", "多少", "几点", "何时", ""
}
BASE_DECISION_KEYWORDS = {
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
"承诺", "保证", "确保", "负责", "同意", "答应"
}
@classmethod
def get_education_config(cls) -> ScenePatterns:
"""教育场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 成绩相关(最高优先级)
(r"成绩|分数|得分|满分|及格|不及格", 6),
(r"GPA|绩点|学分|平均分", 6),
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
# 学籍信息
(r"学号|学生证|教师工号|工号", 5),
(r"班级|年级|专业|院系", 4),
# 课程相关
(r"课程|科目|学科|必修|选修", 4),
(r"教材|课本|教科书|参考书", 4),
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
# 学科内容(新增)
(r"微积分|导数|积分|函数|极限|微分", 4),
(r"代数|几何|三角|概率|统计", 4),
(r"物理|化学|生物|历史|地理", 4),
(r"英语|语文|数学|政治|哲学", 4),
(r"定义|定理|公式|概念|原理|法则", 3),
(r"例题|解题|证明|推导|计算", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 教学活动
(r"作业|练习|习题|题目", 3),
(r"考试|测验|测试|考核|期中|期末", 3),
(r"上课|下课|课堂|讲课", 2),
(r"提问|回答|发言|讨论", 2),
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
# 时间安排
(r"课表|课程表|时间表", 3),
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"老师|教师|同学|学生", 1),
(r"教室|实验室|图书馆", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
"举手", "请坐", "很好", "不错", "继续",
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"为啥", "", "咋办", "怎样", "如何做",
"能不能", "可不可以", "行不行", "对不对", "是不是",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"必考", "重点", "考点", "难点", "关键",
"记住", "背诵", "掌握", "理解", "复习",
}
)
@classmethod
def get_online_service_config(cls) -> ScenePatterns:
"""在线服务场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 工单相关(最高优先级)
(r"工单号|工单编号|ticket|TK\d+", 6),
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
# 产品信息
(r"产品型号|型号|SKU|产品编号", 5),
(r"序列号|SN|设备号", 5),
(r"版本号|软件版本|固件版本", 4),
# 问题描述
(r"故障|错误|异常|bug|问题", 4),
(r"错误代码|故障代码|error code", 5),
(r"无法|不能|失败|报错", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 服务相关
(r"退款|退货|换货|补发", 4),
(r"发票|收据|凭证", 3),
(r"物流|快递|运单号", 3),
(r"保修|质保|售后", 3),
# 时效相关
(r"SLA|响应时间|处理时长", 4),
(r"超时|延迟|等待", 2),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"客服|工程师|技术支持", 1),
(r"用户|客户|会员", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 在线服务特有填充词
"您好", "请问", "请稍等", "稍等", "马上", "立即",
"正在查询", "正在处理", "正在为您", "帮您查一下",
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
"感谢您的耐心等待", "抱歉让您久等了",
"已记录", "已反馈", "已转接", "已升级",
"祝您生活愉快", "再见", "欢迎下次咨询",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"能否", "可否", "是否", "有没有", "能不能",
"怎么办", "如何处理", "怎么解决",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"立即处理", "马上解决", "尽快", "优先",
"升级", "转接", "派单", "跟进",
"补偿", "赔偿", "退款", "换货",
}
)
@classmethod
def get_outbound_config(cls) -> ScenePatterns:
"""外呼场景配置"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
# 意向相关(最高优先级)
(r"意向|意愿|兴趣|感兴趣", 6),
(r"A类|B类|C类|D类|高意向|低意向", 6),
(r"成交|签约|下单|购买|确认", 6),
# 联系信息(外呼场景中更重要)
(r"预约|约定|安排|确定时间", 5),
(r"下次联系|回访|跟进", 5),
(r"方便|有空|可以|时间", 4),
# 通话状态
(r"接通|未接通|占线|关机|停机", 4),
(r"通话时长|通话时间", 3),
],
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
# 客户信息
(r"姓名|称呼|先生|女士", 3),
(r"公司|单位|职位|职务", 3),
(r"需求|要求|期望", 3),
# 跟进状态
(r"跟进状态|进展|进度", 3),
(r"已联系|待联系|联系中", 2),
(r"拒绝|不感兴趣|考虑|再说", 3),
],
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
(r"销售|客户经理|业务员", 1),
(r"产品|服务|方案", 1),
],
filler_phrases=cls.BASE_FILLERS | {
# 外呼场景特有填充词
"您好", "", "hello", "打扰了", "不好意思",
"方便接电话吗", "现在方便吗", "占用您一点时间",
"我是", "我们是", "我们公司", "我们这边",
"了解一下", "介绍一下", "简单说一下",
"考虑考虑", "想一想", "再说", "再看看",
"不需要", "不感兴趣", "没兴趣", "不用了",
"好的", "", "可以", "没问题", "那就这样",
"再联系", "回头聊", "有需要再说",
},
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
"有没有", "需不需要", "要不要", "考虑不考虑",
"了解吗", "知道吗", "听说过吗",
"方便吗", "有空吗", "在吗",
},
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
"确定", "决定", "选择", "购买", "下单",
"预约", "安排", "约定", "确认",
"跟进", "回访", "联系", "沟通",
}
)
@classmethod
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
"""根据场景名称获取配置
Args:
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
fallback_to_generic: 如果场景不存在,是否降级到通用配置
Returns:
对应场景的配置,如果场景不存在:
- fallback_to_generic=True: 返回通用配置(仅基础规则)
- fallback_to_generic=False: 抛出异常
"""
scene_map = {
'education': cls.get_education_config,
'online_service': cls.get_online_service_config,
'outbound': cls.get_outbound_config,
}
if scene in scene_map:
return scene_map[scene]()
if fallback_to_generic:
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
return cls.get_generic_config()
else:
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
@classmethod
def get_generic_config(cls) -> ScenePatterns:
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
"""
return ScenePatterns(
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
low_priority_patterns=cls.BASE_LOW_PRIORITY,
filler_phrases=cls.BASE_FILLERS,
question_keywords=cls.BASE_QUESTION_KEYWORDS,
decision_keywords=cls.BASE_DECISION_KEYWORDS
)
@classmethod
def get_all_scenes(cls) -> List[str]:
"""获取所有预定义场景的列表"""
return ['education', 'online_service', 'outbound']
@classmethod
def is_scene_supported(cls, scene: str) -> bool:
"""检查场景是否有专门的配置支持
Args:
scene: 场景名称
Returns:
True: 有专门配置
False: 将使用通用配置
"""
return scene in cls.get_all_scenes()

View File

@@ -1932,17 +1932,17 @@ def preprocess_data(
Returns:
经过清洗转换后的 DialogData 列表
"""
print("\n=== 数据预处理 ===")
logger.debug("=== 数据预处理 ===")
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
DataPreprocessor,
)
preprocessor = DataPreprocessor()
try:
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
return cleaned_data
except Exception as e:
print(f"数据预处理过程中出现错误: {e}")
logger.error(f"数据预处理过程中出现错误: {e}")
raise
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
Returns:
带 chunks 的 DialogData 列表
"""
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
if not data:
raise ValueError("预处理数据为空,无法进行分块")
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: Optional[str] = None,
llm_client: Optional[Any] = None,
skip_cleaning: bool = True,
pruning_config: Optional[Dict] = None,
) -> List[DialogData]:
"""包含数据预处理步骤的完整分块流程
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
input_data_path: 输入数据路径
llm_client: LLM 客户端
skip_cleaning: 是否跳过数据清洗步骤默认False
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
Returns:
带 chunks 的 DialogData 列表
"""
print("\n=== 完整数据处理流程(包含预处理)===")
logger.debug("=== 完整数据处理流程(包含预处理)===")
if input_data_path is None:
input_data_path = os.path.join(
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
pruner = SemanticPruner(llm_client=llm_client)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
if pruning_config:
# 使用传入的配置
config = PruningConfig(**pruning_config)
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
else:
# 使用默认配置(关闭剪枝)
config = None
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
pruner = SemanticPruner(config=config, llm_client=llm_client)
# 记录单对话场景下剪枝前的消息数量
single_dialog_original_msgs = None
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
print(
logger.debug(
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}"
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
)
else:
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
# 保存剪枝后的数据
try:
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
dp = DataPreprocessor(output_file_path=pruned_output_path)
dp.save_data(preprocessed_data, output_path=pruned_output_path)
except Exception as se:
print(f"保存剪枝结果失败:{se}")
logger.error(f"保存剪枝结果失败:{se}")
except Exception as e:
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
# 步骤3: 对话分块
return await get_chunked_dialogs_from_preprocessed(

View File

@@ -1,5 +1,7 @@
import os
from typing import Optional
from typing import Optional, List, Any
from enum import Enum
from pathlib import Path
from app.core.logging_config import get_memory_logger
from app.core.memory.models.message_models import DialogData, Chunk
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
logger = get_memory_logger(__name__)
class ChunkerStrategy(Enum):
"""Supported chunking strategies."""
RECURSIVE = "RecursiveChunker"
SEMANTIC = "SemanticChunker"
LATE = "LateChunker"
NEURAL = "NeuralChunker"
LLM = "LLMChunker"
@classmethod
def get_valid_strategies(cls) -> List[str]:
"""Get list of valid strategy names."""
return [strategy.value for strategy in cls]
class DialogueChunker:
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
@@ -17,23 +33,51 @@ class DialogueChunker:
of different chunking strategies to dialogue data.
"""
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
"""Initialize the DialogueChunker with a specific chunking strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
llm_client: LLM client instance (required for LLMChunker strategy)
Raises:
ValueError: If chunker_strategy is invalid or required parameters are missing
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Validate strategy
valid_strategies = ChunkerStrategy.get_valid_strategies()
if chunker_strategy not in valid_strategies:
raise ValueError(
f"Invalid chunker_strategy: '{chunker_strategy}'. "
f"Must be one of {valid_strategies}"
)
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
self.chunker_strategy = chunker_strategy
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
try:
# Load and validate configuration
chunker_config_dict = get_chunker_config(chunker_strategy)
if not chunker_config_dict:
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# Initialize chunker client
if self.chunker_config.chunker_strategy == "LLMChunker":
if not llm_client:
raise ValueError("llm_client is required for LLMChunker strategy")
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
self.chunker_client = ChunkerClient(self.chunker_config)
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
except Exception as e:
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
raise
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
"""Process a dialogue by generating chunks and adding them to the DialogData object.
Args:
@@ -43,54 +87,125 @@ class DialogueChunker:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
ValueError: If dialogue is invalid or chunking fails
Exception: If chunking process encounters an error
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# Validate input
if not dialogue:
raise ValueError("dialogue cannot be None")
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
f"Context: {dialogue.context is not None}, "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
)
logger.info(
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
f"using strategy: {self.chunker_strategy}"
)
try:
# Generate chunks
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
chunks = result_dialogue.chunks
return chunks
# Validate results
if not chunks or len(chunks) == 0:
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs)}, "
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
logger.info(
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
)
return chunks
except ValueError:
# Re-raise validation errors
raise
except Exception as e:
logger.error(
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
exc_info=True
)
raise
def save_chunking_results(
self,
chunks: List[Chunk],
dialogue: DialogData,
output_path: Optional[str] = None,
preview_length: int = 100
) -> str:
"""Save the chunking results to a file and return the output path.
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output
chunks: List of Chunk objects to save
dialogue: The DialogData object that was processed
output_path: Optional path to save the output (defaults to current directory)
preview_length: Maximum length of content preview (default: 100)
Returns:
The path where the output was saved
Raises:
ValueError: If chunks or dialogue is invalid
IOError: If file writing fails
"""
if not output_path:
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
# Validate input
if not chunks:
raise ValueError("chunks list cannot be empty")
if not dialogue:
raise ValueError("dialogue cannot be None")
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
# Generate default output path if not provided
if not output_path:
output_dir = Path(__file__).parent.parent.parent
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
logger.info(f"Saving chunking results to: {output_path}")
try:
# Prepare output content
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
f"Generated {len(chunks)} chunks:",
""
]
for i, chunk in enumerate(chunks, 1):
content_preview = chunk.content[:preview_length] if chunk.content else ""
if len(chunk.content) > preview_length:
content_preview += "..."
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {content_preview}")
if chunk.metadata:
output_lines.append(f" Metadata: {chunk.metadata}")
output_lines.append("")
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
# Write to file
with open(output_path, "w", encoding="utf-8") as f:
f.write("\n".join(output_lines))
logger.info(f"Chunking results saved to: {output_path}")
return output_path
logger.info(f"Successfully saved chunking results to: {output_path}")
return output_path
except IOError as e:
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
raise

View File

@@ -327,7 +327,7 @@ class MultiOntologyParser:
Example:
>>> parser = MultiOntologyParser([
... "General_purpose_entity.ttl",
... "app/core/memory/ontology_services/General_purpose_entity.ttl",
... "domain_specific.owl"
... ])
>>> registry = parser.parse_all()

View File

@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
user_id: str,
entities: str,
statements: str,
language: str = "zh"
language: str = "zh",
user_display_name: str = None
) -> str:
"""
Renders the user summary prompt using the user_summary.jinja2 template.
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
entities: Core entities with frequency information
statements: Representative statement samples
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
Returns:
Rendered prompt content as string
"""
# 如果没有提供 user_display_name使用默认值
if user_display_name is None:
user_display_name = "该用户" if language == "zh" else "the user"
template = prompt_env.get_template("user_summary.jinja2")
rendered_prompt = template.render(
user_id=user_id,
entities=entities,
statements=statements,
language=language
language=language,
user_display_name=user_display_name
)
# 记录渲染结果到提示日志
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
'user_id': user_id,
'entities_len': len(entities),
'statements_len': len(statements),
'language': language
'language': language,
'user_display_name': user_display_name
})
return rendered_prompt
@@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt(
})
return rendered_prompt
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
"""
Renders the interest filter prompt using the interest_filter.jinja2 template.
Args:
tag_list: Comma-separated string of raw tags to filter
language: Output language ("zh" for Chinese, "en" for English)
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("interest_filter.jinja2")
rendered_prompt = template.render(tag_list=tag_list, language=language)
log_prompt_rendering('interest filter', rendered_prompt)
return rendered_prompt

View File

@@ -0,0 +1,67 @@
{% if language == "zh" %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
- '川味可视化', '川菜' → '烹饪'
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
- '吉他', '指弹', '琴谱' → '吉他'
- '跑步', '5公里', '跑鞋' → '跑步'
- '瑜伽垫', '瑜伽课' → '瑜伽'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
- If multiple tags all point to '攀岩', output '攀岩' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
**Output format**: Return a list of concise interest activity names in Chinese.
**Example**:
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
{% else %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
- 'morning meditation streak', 'mind-body peak' → 'meditation'
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
- 'open source project', 'data visualization tool', 'Python' → 'programming'
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
- 'running', '5km', 'running shoes' → 'running'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
**Output format**: Return a list of concise interest activity names in English.
**Example**:
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
{% endif %}

View File

@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% endif %}
===Inputs===
{% if user_id %}
- User ID: {{ user_id }}
{% if user_display_name %}
- User Display Name: {{ user_display_name }}
{% endif %}
{% if entities %}
- Core Entities & Frequency: {{ entities }}
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
3. Avoid excessive adjectives and empty phrases
4. Strictly follow the output format specified below
{% if language == "zh" %}
**【严格人称规定】**
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
- 绝对禁止使用用户ID如 {{ user_id }})来称呼用户
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA
{% else %}
**【STRICT PRONOUN RULES】**
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
{% endif %}
**Section-Specific Requirements:**
{% if language == "zh" %}
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
{% if language == "zh" %}
Example Input:
- User ID: user_12345
- User Display Name: 张三
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
Example Output:
【基本介绍】
我是张三一名充满热情的高级产品经理。在过去的5年里专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
张三一名充满热情的高级产品经理,在深圳工作。在过去的5年里张三专注于AI和数据驱动的产品设计致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
【性格特点】
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
@@ -121,13 +135,13 @@ Example Output:
"让每一个产品决策都充满温度。"
{% else %}
Example Input:
- User ID: user_12345
- User Display Name: John
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
Example Output:
【Basic Introduction】
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
【Personality Traits】
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.

View File

@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field
T = TypeVar("T")
class RedBearModelConfig(BaseModel):
"""模型配置基类"""
model_name: str
provider: str
api_key: str
base_url: Optional[str] = None
is_omni: bool = False # 是否为 Omni 模型
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用可通过环境变量 LLM_TIMEOUT 配置
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
# 最大重试次数 - 默认2次以避免过长等待可通过环境变量 LLM_MAX_RETRIES 配置
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
concurrency: int = 5 # 并发限流
concurrency: int = 5 # 并发限流
extra_params: Dict[str, Any] = {}
class RedBearModelFactory:
"""模型工厂类"""
@classmethod
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
# 打印供应商信息用于调试
from app.core.logging_config import get_business_logger
logger = get_business_logger()
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
import httpx
if not config.base_url:
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
timeout_config = httpx.Timeout(
timeout=config.timeout,
connect=60.0,
read=config.timeout,
write=60.0,
pool=10.0,
)
return {
"model": config.model_name,
"base_url": config.base_url,
"api_key": config.api_key,
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
# 使用 httpx.Timeout 对象来设置详细的超时配置
@@ -65,7 +89,7 @@ class RedBearModelFactory:
"timeout": timeout_config,
"max_retries": config.max_retries,
**config.extra_params
}
}
elif provider == ModelProvider.DASHSCOPE:
# DashScope (通义千问) 使用自己的参数格式
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
@@ -82,7 +106,7 @@ class RedBearModelFactory:
# region 从 base_url 或 extra_params 获取
from botocore.config import Config as BotoConfig
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
# Configure with increased connection pool
@@ -90,16 +114,16 @@ class RedBearModelFactory:
max_pool_connections=max_pool_connections,
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
)
# 标准化模型 ID自动转换简化名称为完整 Bedrock Model ID
model_id = normalize_bedrock_model_id(config.model_name)
params = {
"model_id": model_id,
"config": boto_config,
**config.extra_params
}
# 解析 API key (格式: access_key_id:secret_access_key)
if config.api_key and ":" in config.api_key:
access_key_id, secret_access_key = config.api_key.split(":", 1)
@@ -107,45 +131,52 @@ class RedBearModelFactory:
params["aws_secret_access_key"] = secret_access_key
elif config.api_key:
params["aws_access_key_id"] = config.api_key
# 设置 region
if config.base_url:
params["region_name"] = config.base_url
elif "region_name" not in params:
params["region_name"] = "us-east-1" # 默认区域
return params
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
@classmethod
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
"""根据提供商获取模型参数"""
provider = config.provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
return {
return {
"model": config.model_name,
# "base_url": config.base_url,
"jina_api_key": config.api_key,
**config.extra_params
}
}
else:
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
"""根据模型提供商获取对应的模型类"""
provider = config.provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
# dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni:
from langchain_openai import ChatOpenAI
return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if type == ModelType.LLM:
from langchain_openai import OpenAI
return OpenAI
return OpenAI
elif type == ModelType.CHAT:
from langchain_openai import ChatOpenAI
return ChatOpenAI
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.chat_models import ChatTongyi
return ChatTongyi
elif provider == ModelProvider.OLLAMA:
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaLLM
return OllamaLLM
elif provider == ModelProvider.BEDROCK:
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_openai import OpenAIEmbeddings
return OpenAIEmbeddings
return OpenAIEmbeddings
elif provider == ModelProvider.DASHSCOPE:
from langchain_community.embeddings import DashScopeEmbeddings
return DashScopeEmbeddings
return DashScopeEmbeddings
elif provider == ModelProvider.OLLAMA:
from langchain_ollama import OllamaEmbeddings
return OllamaEmbeddings
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
def get_provider_rerank_class(provider: str):
"""根据模型提供商获取对应的模型类"""
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
provider = provider.lower()
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
from langchain_community.document_compressors import JinaRerank
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
return JinaRerank
# elif provider == ModelProvider.OLLAMA:
# from langchain_ollama import OllamaEmbeddings
# return OllamaEmbeddings
else:
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)

View File

@@ -6,6 +6,8 @@ models:
description: AI21 Labs大语言模型completion生成模式256000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: bedrock
@@ -15,6 +17,9 @@ models:
description: Amazon Nova大语言模型支持智能体思考、工具调用、流式工具调用、视觉能力300000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -28,6 +33,9 @@ models:
description: Anthropic Claude大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -42,6 +50,8 @@ models:
description: Cohere大语言模型支持智能体思考、工具调用、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -54,6 +64,9 @@ models:
description: DeepSeek大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -67,6 +80,8 @@ models:
description: Meta Llama大语言模型支持智能体思考、工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -78,6 +93,8 @@ models:
description: Mistral AI大语言模型支持智能体思考、工具调用32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -89,6 +106,8 @@ models:
description: OpenAI大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -101,6 +120,8 @@ models:
description: Qwen大语言模型支持智能体思考、工具调用、流式工具调用32768上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -113,6 +134,8 @@ models:
description: amazon.rerank-v1:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -122,6 +145,8 @@ models:
description: cohere.rerank-v3-5:0重排序模型5120上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: bedrock
@@ -131,6 +156,9 @@ models:
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型支持视觉能力8192上下文窗口
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 文本嵌入模型
- vision
@@ -141,6 +169,8 @@ models:
description: amazon.titan-embed-text-v1文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -150,6 +180,8 @@ models:
description: amazon.titan-embed-text-v2:0文本嵌入模型8192上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -159,6 +191,8 @@ models:
description: Cohere Embed 3 English文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
@@ -168,6 +202,8 @@ models:
description: Cohere Embed 3 Multilingual文本嵌入模型512上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本嵌入模型
logo: bedrock
logo: bedrock

View File

@@ -6,6 +6,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-14B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -16,6 +18,8 @@ models:
description: DeepSeek-R1-Distill-Qwen-32B大语言模型支持智能体思考32000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -26,6 +30,8 @@ models:
description: DeepSeek-R1大语言模型支持智能体思考131072超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -36,6 +42,8 @@ models:
description: DeepSeek-V3.1大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -46,6 +54,8 @@ models:
description: DeepSeek-V3.2-exp实验版大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -56,6 +66,8 @@ models:
description: DeepSeek-V3.2大语言模型支持智能体思考131072超大上下文窗口对话模式支持丰富生成参数调节
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -66,6 +78,8 @@ models:
description: DeepSeek-V3大语言模型支持智能体思考64000上下文窗口对话模式支持文本与JSON格式输出
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -76,6 +90,8 @@ models:
description: farui-plus大语言模型支持多工具调用、智能体思考、流式工具调用12288上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +104,8 @@ models:
description: GLM-4.7大语言模型支持多工具调用、智能体思考、流式工具调用202752超大上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +118,9 @@ models:
description: qvq-max-latest大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -112,6 +133,9 @@ models:
description: qvq-max大语言模型支持视觉、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- vision
@@ -124,6 +148,8 @@ models:
description: qwen-coder-turbo-0919代码专用大语言模型支持智能体思考131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -135,6 +161,8 @@ models:
description: qwen-max-latest大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -147,6 +175,8 @@ models:
description: qwen-max-longcontext长上下文大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -159,6 +189,8 @@ models:
description: qwen-max大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -171,6 +203,8 @@ models:
description: qwen-mt-plus多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -182,6 +216,8 @@ models:
description: qwen-mt-turbo轻量化多语言翻译大语言模型支持智能体思考16384上下文窗口对话模式支持多语种互译与领域翻译适配
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 翻译模型
@@ -193,6 +229,8 @@ models:
description: qwen-plus-0112大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -205,6 +243,8 @@ models:
description: qwen-plus-0125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -217,6 +257,8 @@ models:
description: qwen-plus-0723大语言模型支持多工具调用、智能体思考、流式工具调用32000上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -229,6 +271,8 @@ models:
description: qwen-plus-0806大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -241,6 +285,8 @@ models:
description: qwen-plus-0919大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -253,6 +299,8 @@ models:
description: qwen-plus-1125大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -265,6 +313,8 @@ models:
description: qwen-plus-1127大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式支持联网搜索已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -277,6 +327,8 @@ models:
description: qwen-plus-1220大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -289,6 +341,10 @@ models:
description: qwen-vl-max多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -302,6 +358,10 @@ models:
description: qwen-vl-plus-0809多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -315,6 +375,10 @@ models:
description: qwen-vl-plus-2025-01-02多模态大模型支持视觉理解、智能体思考、视频理解32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -328,6 +392,10 @@ models:
description: qwen-vl-plus-2025-01-25多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -341,6 +409,10 @@ models:
description: qwen-vl-plus-latest多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -354,6 +426,10 @@ models:
description: qwen-vl-plus多模态大模型支持视觉理解、智能体思考、视频理解131072上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -367,6 +443,8 @@ models:
description: qwen2.5-0.5b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用32768上下文窗口对话模式未废弃
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -379,6 +457,8 @@ models:
description: qwen3-14b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -391,6 +471,8 @@ models:
description: qwen3-235b-a22b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -403,6 +485,8 @@ models:
description: qwen3-235b-a22b-thinking-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -415,6 +499,8 @@ models:
description: qwen3-235b-a22b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -427,6 +513,8 @@ models:
description: qwen3-30b-a3b-instruct-2507大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -439,6 +527,8 @@ models:
description: qwen3-30b-a3b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -451,6 +541,8 @@ models:
description: qwen3-32b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -463,6 +555,8 @@ models:
description: qwen3-4b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -475,6 +569,8 @@ models:
description: qwen3-8b大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -487,6 +583,8 @@ models:
description: qwen3-coder-30b-a3b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -498,6 +596,8 @@ models:
description: qwen3-coder-480b-a35b-instruct大语言模型支持智能体思考262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -509,6 +609,8 @@ models:
description: qwen3-coder-plus-2025-09-23大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -520,6 +622,8 @@ models:
description: qwen3-coder-plus大语言模型支持智能体思考1000000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- 代码模型
@@ -531,6 +635,8 @@ models:
description: qwen3-max-2025-09-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -544,6 +650,8 @@ models:
description: qwen3-max-2026-01-23大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -557,6 +665,8 @@ models:
description: qwen3-max-preview大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -569,6 +679,8 @@ models:
description: qwen3-max大语言模型支持多工具调用、智能体思考、流式工具调用262144上下文窗口对话模式支持联网搜索
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -582,6 +694,8 @@ models:
description: qwen3-next-80b-a3b-instruct大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -594,6 +708,8 @@ models:
description: qwen3-next-80b-a3b-thinking大语言模型支持多工具调用、智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -606,6 +722,11 @@ models:
description: qwen3-omni-flash-2025-12-01多模态大语言模型支持视觉、智能体思考、视频、音频能力65536上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
- audio
is_omni: true
tags:
- 大语言模型
- 多模态模型
@@ -620,6 +741,10 @@ models:
description: qwen3-vl-235b-a22b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -635,6 +760,10 @@ models:
description: qwen3-vl-235b-a22b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -650,6 +779,10 @@ models:
description: qwen3-vl-30b-a3b-instruct多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -665,6 +798,10 @@ models:
description: qwen3-vl-30b-a3b-thinking多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -680,6 +817,10 @@ models:
description: qwen3-vl-flash多模态大语言模型支持多工具调用、智能体思考、流式工具调用、视觉、视频能力131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -695,6 +836,10 @@ models:
description: qwen3-vl-plus-2025-09-23多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -708,6 +853,10 @@ models:
description: qwen3-vl-plus多模态大语言模型支持视觉、智能体思考、视频能力262144上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
@@ -721,6 +870,8 @@ models:
description: qwq-32b大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -732,6 +883,8 @@ models:
description: qwq-plus-0305大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -743,6 +896,8 @@ models:
description: qwq-plus大语言模型支持智能体思考、流式工具调用131072上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -754,6 +909,8 @@ models:
description: gte-rerank-v2重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -763,6 +920,8 @@ models:
description: gte-rerank重排序模型4000上下文窗口
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 重排序模型
logo: dashscope
@@ -772,6 +931,9 @@ models:
description: multimodal-embedding-v1多模态嵌入模型支持视觉能力8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 嵌入模型
- 多模态模型
@@ -783,6 +945,8 @@ models:
description: text-embedding-v1文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -793,6 +957,8 @@ models:
description: text-embedding-v2文本嵌入模型2048上下文窗口最大分块数25
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -803,6 +969,8 @@ models:
description: text-embedding-v3文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
@@ -813,7 +981,9 @@ models:
description: text-embedding-v4文本嵌入模型8192上下文窗口最大分块数10
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 嵌入模型
- 文本嵌入
logo: dashscope
logo: dashscope

View File

@@ -6,7 +6,7 @@ from typing import Callable
import yaml
from sqlalchemy.orm import Session
from app.models.models_model import ModelBase, ModelProvider
from app.models.models_model import ModelBase, ModelProvider, ModelConfig
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
print(f"\n正在加载 {provider.value}{len(models)} 个模型...")
for model_data in models:
config_sync_fields = {
"logo": None,
"capability": None,
"is_omni": None,
"name": None,
"provider": None,
"type": None,
"description": None
}
try:
# 检查模型是否已存在
existing = db.query(ModelBase).filter(
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
# 更新现有模型配置
for key, value in model_data.items():
setattr(existing, key, value)
# 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey
sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
if sync_fields:
# 批量更新 ModelConfig
update_kwargs = {k: model_data[k] for k in sync_fields}
db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
update_kwargs,
synchronize_session=False
)
# 更新 ModelApiKey 的 capability 和 is_omni
if 'capability' in model_data or 'is_omni' in model_data:
from app.models.models_model import ModelApiKey, model_config_api_key_association
api_key_update = {}
if 'capability' in model_data:
api_key_update['capability'] = model_data['capability']
if 'is_omni' in model_data:
api_key_update['is_omni'] = model_data['is_omni']
if api_key_update:
# 查找所有关联的 API Key
api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
ModelConfig,
ModelConfig.id == model_config_api_key_association.c.model_config_id
).filter(ModelConfig.model_id == existing.id).distinct().all()
if api_key_ids:
api_key_ids = [aid[0] for aid in api_key_ids]
db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
api_key_update,
synchronize_session=False
)
db.commit()
if not silent:
print(f"更新成功: {model_data['name']}")

View File

@@ -6,12 +6,19 @@ models:
description: chatgpt-4o-latest大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
- audio
- video
is_omni: true
tags:
- 大语言模型
- multi-tool-call
- agent-thought
- stream-tool-call
- vision
- audio
- video
logo: openai
- name: gpt-3.5-turbo-0125
type: llm
@@ -19,6 +26,8 @@ models:
description: gpt-3.5-turbo-0125大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -31,6 +40,8 @@ models:
description: gpt-3.5-turbo-1106大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -43,6 +54,8 @@ models:
description: gpt-3.5-turbo-16k大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -55,6 +68,8 @@ models:
description: gpt-3.5-turbo-instruct大语言模型4096上下文窗口文本补全模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: openai
@@ -64,6 +79,8 @@ models:
description: gpt-3.5-turbo大语言模型支持多工具调用、智能体思考、流式工具调用16385上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -76,6 +93,8 @@ models:
description: gpt-4-0125-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -88,6 +107,8 @@ models:
description: gpt-4-1106-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -100,6 +121,9 @@ models:
description: gpt-4-turbo-2024-04-09大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -113,6 +137,8 @@ models:
description: gpt-4-turbo-preview大语言模型支持多工具调用、智能体思考、流式工具调用128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -125,6 +151,9 @@ models:
description: gpt-4-turbo大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力128000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -138,6 +167,8 @@ models:
description: o1-preview大语言模型支持智能体思考128000上下文窗口对话模式已废弃
is_deprecated: true
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -148,6 +179,9 @@ models:
description: o1大语言模型支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- multi-tool-call
@@ -162,6 +196,9 @@ models:
description: o3-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -176,6 +213,8 @@ models:
description: o3-mini-2025-01-31大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -189,6 +228,8 @@ models:
description: o3-mini大语言模型支持智能体思考、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -202,6 +243,9 @@ models:
description: o3-pro-2025-06-10大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -215,6 +259,9 @@ models:
description: o3-pro大语言模型支持智能体思考、工具调用、视觉能力、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -228,6 +275,9 @@ models:
description: o3大语言模型支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -242,6 +292,9 @@ models:
description: o4-mini-2025-04-16大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -256,6 +309,9 @@ models:
description: o4-mini大语言模型支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出200000上下文窗口对话模式
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- agent-thought
@@ -270,6 +326,8 @@ models:
description: text-embedding-3-large文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -279,6 +337,8 @@ models:
description: text-embedding-3-small文本向量模型8191上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
@@ -288,6 +348,8 @@ models:
description: text-embedding-ada-002文本向量模型8097上下文窗口最大分块数32
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 文本向量模型
logo: openai
logo: openai

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModel
ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
@@ -69,9 +69,27 @@ class DifyConverter(BaseConverter):
}
def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
return func
def config_validate(
self,
node_id: str,
node_name: str,
config: type[BaseNodeConfig],
value: dict
):
try:
return config.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
@staticmethod
def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression))
@@ -80,14 +98,16 @@ class DifyConverter(BaseConverter):
if not var_selector:
return ""
selector = var_selector.split('.')
if len(selector) != 2:
if len(selector) not in [2, 3] and var_selector != "context":
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
var_selector = ".".join(selector)
mapping = {
"sys.query": "sys.message"
} | self.node_output_map
"sys.query": "sys.message"
} | self.node_output_map
var_selector = mapping.get(var_selector, var_selector)
return var_selector
@@ -237,7 +257,7 @@ class DifyConverter(BaseConverter):
node_id=node["id"],
node_name=node_data["title"],
name=var["variable"],
detail=f"Unsupport Variable type for start node: {var_type}"
detail=f"Unsupported Variable type for start node: {var_type}"
)
)
continue
@@ -253,9 +273,11 @@ class DifyConverter(BaseConverter):
max_length=var.get("max_length"),
)
start_vars.append(var_def)
return StartNodeConfig(
result = StartNodeConfig.model_construct(
variables=start_vars
).model_dump()
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
return result
def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -270,16 +292,18 @@ class DifyConverter(BaseConverter):
for category in node_data["classes"]:
self.branch_node_cache[node["id"]].append(category["id"])
categories.append(
ClassifierConfig(
ClassifierConfig.model_construct(
class_name=category["name"],
)
)
return QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
return result
def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -308,14 +332,16 @@ class DifyConverter(BaseConverter):
messages.append(
MessageConfig(
role="user",
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
content=self.trans_variable_format(
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
)
)
)
vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"]
) if vision else None
return LLMNodeConfig.model_construct(
result = LLMNodeConfig.model_construct(
model_id=None,
context=context,
memory=memory,
@@ -323,12 +349,16 @@ class DifyConverter(BaseConverter):
vision_input=vision_input,
messages=messages
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
return result
def convert_end_node_config(self, node: dict) -> dict:
node_data = node["data"]
return EndNodeConfig(
output=self.trans_variable_format(node_data["answer"]),
result = EndNodeConfig.model_construct(
output=self.trans_variable_format(node_data.get("answer", "")),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
return result
def convert_if_else_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -359,9 +389,11 @@ class DifyConverter(BaseConverter):
)
)
self.branch_node_cache[node["id"]].append(case_id)
return IfElseNodeConfig(
result = IfElseNodeConfig.model_construct(
cases=cases
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
return result
def convert_loop_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -370,7 +402,7 @@ class DifyConverter(BaseConverter):
for condition in node_data["break_conditions"]:
right_value = condition["value"]
conditions.append(
LoopConditionDetail(
LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format(
@@ -383,7 +415,7 @@ class DifyConverter(BaseConverter):
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
)
)
condition_config = ConditionsConfig(
condition_config = ConditionsConfig.model_construct(
logical_operator=logical_operator,
expressions=conditions
)
@@ -392,9 +424,9 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable["value"])
right_value = self._process_list_variable_litearl(variable.get("value", ""))
else:
right_value = self.convert_variable_type(right_value_type, variable["value"])
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append(
CycleVariable(
name=variable["label"],
@@ -403,23 +435,28 @@ class DifyConverter(BaseConverter):
input_type=right_input_type
)
)
return LoopNodeConfig(
result = LoopNodeConfig.model_construct(
condition=condition_config,
cycle_vars=loop_variables,
max_loop=node_data["loop_count"]
max_loop=node_data.get("loop_count", 10)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
return result
def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"]
return IterationNodeConfig(
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data["output_type"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
return result
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
@@ -435,16 +472,18 @@ class DifyConverter(BaseConverter):
operation=self.convert_assignment_operator(assignment["operation"])
)
)
return AssignerNodeConfig(
result = AssignerNodeConfig.model_construct(
assignments=assignments
).model_dump()
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
return result
def convert_code_node_config(self, node: dict) -> dict:
node_data = node["data"]
input_variables = []
for input_variable in node_data["variables"]:
input_variables.append(
InputVariable(
InputVariable.model_construct(
name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
)
@@ -453,7 +492,7 @@ class DifyConverter(BaseConverter):
output_variables = []
for output_variable in node_data["outputs"]:
output_variables.append(
OutputVariable(
OutputVariable.model_construct(
name=output_variable,
type=node_data["outputs"][output_variable]["type"],
)
@@ -461,18 +500,20 @@ class DifyConverter(BaseConverter):
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
return CodeNodeConfig(
result = CodeNodeConfig.model_construct(
input_variables=input_variables,
language=node_data["code_language"],
output_variables=output_variables,
code=code
).model_dump()
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
return result
def convert_http_node_config(self, node: dict) -> dict:
node_data = node["data"]
if node_data["authorization"] != 'no-auth':
if node_data["authorization"]["type"] != 'no-auth':
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
auth_config = HttpAuthConfig(
auth_config = HttpAuthConfig.model_construct(
auth_type=auth_type,
header=node_data["authorization"]["config"].get("header"),
api_key=node_data["authorization"]["config"].get("api_key"),
@@ -504,7 +545,7 @@ class DifyConverter(BaseConverter):
body_content = ""
headers = {}
for header in node_data["headers"].split("\n"):
for header in node_data.get("headers", "").split("\n"):
if not header:
continue
@@ -522,7 +563,7 @@ class DifyConverter(BaseConverter):
))
params = {}
for param in node_data["params"].split("\n"):
for param in node_data.get("params", "").split("\n"):
if not param:
continue
@@ -547,7 +588,7 @@ class DifyConverter(BaseConverter):
default_body = ""
default_header = {}
default_status_code = 0
for var in node_data["default_value"]:
for var in node_data.get("default_value") or []:
if var["key"] == "body":
default_body = var["value"]
elif var["key"] == "header":
@@ -561,45 +602,50 @@ class DifyConverter(BaseConverter):
)
self.error_branch_node_cache.append(node['id'])
return HttpRequestNodeConfig(
result = HttpRequestNodeConfig.model_construct(
method=node_data["method"].upper(),
url=node_data["url"],
auth=auth_config,
body=HttpContentTypeConfig(
body=HttpContentTypeConfig.model_construct(
content_type=self.convert_http_content_type(node_data["body"]["type"]),
data=body_content,
),
headers=headers,
params=params,
verify_ssl=node_data["ssl_verify"],
timeouts=HttpTimeOutConfig(
timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
),
retry=HttpRetryConfig(
retry=HttpRetryConfig.model_construct(
enable=node_data["retry_config"]["retry_enabled"],
max_attempts=node_data["retry_config"]["max_retries"],
retry_interval=node_data["retry_config"]["retry_interval"],
),
error_handle=HttpErrorHandleConfig(
error_handle=HttpErrorHandleConfig.model_construct(
method=error_handle_type,
default=default_value,
)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
return result
def convert_jinja_render_node_config(self, node: dict) -> dict:
node_data = node["data"]
mapping = []
for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig(
mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"])
))
return JinjaRenderNodeConfig(
result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"],
mapping=mapping,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
return result
def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -609,10 +655,13 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.",
))
return KnowledgeRetrievalNodeConfig.model_construct(
result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
return result
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
@@ -623,46 +672,53 @@ class DifyConverter(BaseConverter):
)
)
params = []
for param in node_data["parameters"]:
for param in node_data.get("parameters", []):
params.append(
ParamsConfig(
ParamsConfig.model_construct(
name=param["name"],
desc=param["description"],
required=param["required"],
type=param["type"],
)
)
return ParameterExtractorNodeConfig.model_construct(
result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]),
params=params,
prompt=node_data["instruction"]
prompt=node_data.get("instruction")
).model_dump()
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
return result
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"]
advanced_settings = node_data.get("advanced_settings", {})
group_variables = {}
group_type = {}
if not group_enable:
if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables["output"] = [
self._process_list_variable_litearl(variable)
for variable in node_data["variables"]
]
group_type["output"] = node_data["output_type"]
else:
for group in node_data["advanced_settings"]["groups"]:
for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable)
for variable in group["variables"]
]
group_type[group["group_name"]] = group["output_type"]
return VariableAggregatorNodeConfig(
group=group_enable,
result = VariableAggregatorNodeConfig.model_construct(
group=advanced_settings.get("group_enabled", False),
group_variables=group_variables,
group_type=group_type,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
return result
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
@@ -671,4 +727,4 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.",
))
return {}
return {}

View File

@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)
def map_node_type(self, platform_node_type) -> str:
return self.NODE_TYPE_MAPPING.get(platform_node_type)
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property
def origin_nodes(self):
@@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True
def validate_config(self) -> bool:
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
if not all(field in self.config for field in require_fields):
return False
@@ -179,8 +179,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_type = node_data["type"]
try:
converter = self.get_node_convert(node_type)
if converter is None:
raise Exception(f"node type not supported - {node_type}")
if node_type not in self.CONFIG_CONVERT_MAP:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"node type {node_type} is unsupported",
))
return converter(node)
except Exception as e:
self.errors.append(ExceptionDefineition(

View File

@@ -127,7 +127,7 @@ class EventStreamHandler:
yield {
"event": "message",
"data": {
"chunk": data.get("chunk")
"content": data.get("chunk")
}
}

View File

@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
yield {
"event": "message",
"data": {
"chunk": final_chunk
"content": final_chunk
}
}

View File

@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
instance:
The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]).
NumberVariable, ArrayVariable[StringVariable]).
mut:
Whether the variable is mutable.
"""
@@ -152,6 +152,36 @@ class VariablePool:
return None
return var_instance
def get_instance(
self,
selector: str,
default: Any = None,
strict: bool = True
):
"""Retrieve a variable instance from the variable pool.
Args:
selector:
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
Returns:
The variable instance object if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance
def get_value(
self,
selector: str,
@@ -273,38 +303,52 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def get_all_system_vars(self) -> dict[str, Any]:
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
Returns:
系统变量字典
"""
sys_namespace = self.variables.get("sys", {})
if literal:
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
def get_all_conversation_vars(self) -> dict[str, Any]:
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
"""获取所有会话变量
Returns:
会话变量字典
"""
conv_namespace = self.variables.get("conv", {})
if literal:
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
def get_all_node_outputs(self) -> dict[str, Any]:
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
"""获取所有节点输出(运行时变量)
Returns:
节点输出字典,键为节点 ID
"""
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
if literal:
runtime_vars = {
namespace: {
k: v.instance.to_literal()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
else:
runtime_vars = {
namespace: {
k: v.instance.get_value()
for k, v in vars_dict.items()
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
for namespace, vars_dict in self.variables.items()
if namespace not in ("sys", "conv")
}
return runtime_vars
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:

View File

@@ -132,24 +132,24 @@ class WorkflowExecutor:
start_time = datetime.datetime.now()
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
# Execute the workflow
try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
@@ -231,23 +231,23 @@ class WorkflowExecutor:
}
}
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
full_content = ''
self.stream_coordinator.update_scope_activation("sys")
@@ -272,7 +272,7 @@ class WorkflowExecutor:
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type == "node_chunk":
async for msg_event in self.event_handler.handle_node_chunk_event(data):
full_content += msg_event["data"]["chunk"]
full_content += msg_event["data"]["content"]
yield msg_event
elif event_type == "node_error":
@@ -295,12 +295,12 @@ class WorkflowExecutor:
self.graph,
self.execution_context.checkpoint_config
):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
# Flush any remaining chunks
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
result = graph.get_state(self.execution_context.checkpoint_config).values

View File

@@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db
from app.models import AppRelease
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING}
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
"""准备 Agent公共逻辑
Args:
@@ -65,7 +65,7 @@ class AgentNode(BaseNode):
if not release:
raise ValueError(f"Agent 不存在: {agent_id}")
draft_service = DraftRunService(db)
draft_service = AgentRunService(db)
return draft_service, release, message

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Any, AsyncGenerator
@@ -10,8 +11,10 @@ from app.core.config import settings
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.enums import BRANCH_NODES
from app.core.workflow.variable.base_variable import VariableType
from app.services.multimodal_service import PROVIDER_STRATEGIES
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas import FileInput
from app.services.multimodal_service import MultimodalService
logger = logging.getLogger(__name__)
@@ -548,9 +551,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars(),
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
strict=strict
)
@@ -614,16 +617,32 @@ class BaseNode(ABC):
return variable_pool.has(selector)
@staticmethod
async def process_message(provider, content, enable_file=False) -> dict | str | None:
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
if isinstance(content, str):
if enable_file:
return {"text": content}
return content
elif isinstance(content, dict):
trans_tool = PROVIDER_STRATEGIES[provider]()
result = await trans_tool.format_image(content["url"])
return result
raise TypeError('Unexpect input value type')
elif isinstance(content, FileObject):
if content.content_cache.get(provider):
return content.content_cache[provider]
with get_db_read() as db:
multimodel_service = MultimodalService(db, provider)
message = await multimodel_service.process_files(
[FileInput.model_construct(
type=content.type,
url=content.url,
transfer_method=content.transfer_method,
file_type=content.origin_file_type,
upload_file_id=content.file_id
)]
)
if message:
content.content_cache[provider] = message[0]
return message[0]
return None
raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod
def process_model_output(content) -> str:

View File

@@ -91,8 +91,8 @@ class IterationRuntime:
return loopstate
def merge_conv_vars(self):
self.variable_pool.get_all_conversation_vars().update(
self.child_variable_pool.get_all_conversation_vars()
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
async def run_task(self, item, idx):

View File

@@ -156,7 +156,7 @@ class LoopRuntime:
def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
self.child_variable_pool.variables["conv"]
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars

View File

@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
if config.flatten:
outputs['output'] = config.output_type
else:
outputs['output'] = VariableType.ARRAY_STRING
outputs['output'] = VariableType.NESTED_ARRAY
else:
outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs

View File

@@ -24,6 +24,8 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -1,6 +1,7 @@
import asyncio
import json
import logging
import uuid
from typing import Any, Callable, Coroutine
import httpx
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
logger = logging.getLogger(__file__)
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
"""
Build HTTP request body arguments for httpx request methods.
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
))
case HttpContentType.FROM_DATA:
data = {}
content["files"] = {}
for item in self.typed_config.body.data:
if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file":
# TODO: File support (Feature)
pass
content["files"][self._render_template(item.key, variable_pool)] = (
uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data
case HttpContentType.BINARY:
# TODO: File support (Feature)
pass
content["files"] = []
file_instence = variable_pool.get_instance(self.typed_config.body.data)
if isinstance(file_instence, ArrayVariable):
for v in file_instence.value:
if isinstance(v, FileVariable):
content["files"].append(
(
"files", (uuid.uuid4().hex, await v.get_content())
)
)
elif isinstance(file_instence, FileVariable):
content["files"].append(
(
"file", (uuid.uuid4().hex, await file_instence.get_content())
)
)
case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), variable_pool
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
request_func = self._get_client_method(client)
resp = await request_func(
url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool)
**(await self._build_content(variable_pool))
)
resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
if self.typed_config.vision_input and self.typed_config.vision:
file_content = []
files = variable_pool.get_value(self.typed_config.vision_input)
for file in files:
content = await self.process_message(provider, file, self.typed_config.vision)
files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value:
content = await self.process_message(provider, file.value, self.typed_config.vision)
if content:
file_content.append(content)
if messages and messages[-1]["role"] == 'user':

View File

@@ -123,10 +123,10 @@ class NodeFactory:
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
raise ValueError(f"Unsupported node type: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, Field
from app.schemas import FileType
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
return cls.NUMBER
elif isinstance(var, bool):
return cls.BOOLEAN
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE
elif isinstance(var, dict):
return cls.OBJECT
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
class FileObject(BaseModel):
type: FileType
url: str
__file: bool
transfer_method: str
origin_file_type: str
file_id: str | None
content_cache: dict = Field(default_factory=dict)
is_file: bool
class BaseVariable(ABC):

View File

@@ -1,8 +1,10 @@
from typing import Any, TypeVar, Type, Generic
import httpx
from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
from app.core.config import settings
T = TypeVar("T", bound=BaseVariable)
@@ -61,13 +63,16 @@ class FileVariable(BaseVariable):
def valid_value(self, value) -> FileObject:
if isinstance(value, dict):
if not value.get("__file"):
if not value.get("is_file"):
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
return FileObject(
**{
"type": str(value.get('type')),
"transfer_method": value.get("transfer_method"),
"url": value.get('url'),
"__file": True
"file_id": value.get("file_id"),
"origin_file_type": value.get("origin_file_type"),
"is_file": True
}
)
if isinstance(value, FileObject):
@@ -80,8 +85,23 @@ class FileVariable(BaseVariable):
def get_value(self) -> Any:
return self.value.model_dump()
async def get_content(self):
total_bytes = 0
chunks = []
class ArrayObject(BaseVariable, Generic[T]):
async with httpx.AsyncClient() as client:
async with client.stream("GET", self.value.url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(8192):
total_bytes += len(chunk)
if total_bytes > settings.MAX_FILE_SIZE:
raise ValueError(f"File too large: {total_bytes} bytes")
chunks.append(chunk)
return b"".join(chunks)
class ArrayVariable(BaseVariable, Generic[T]):
type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]):
@@ -108,7 +128,7 @@ class ArrayObject(BaseVariable, Generic[T]):
return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable):
class NestedArrayVariable(BaseVariable):
type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]:
@@ -116,23 +136,23 @@ class NestedArrayObject(BaseVariable):
raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = []
for v in value:
if not isinstance(v, ArrayObject):
if not isinstance(v, list):
raise TypeError("All elements must be of type list")
final_value.append(v)
final_value.append(make_array(AnyVariable, v))
return final_value
def to_literal(self) -> str:
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
def get_value(self) -> Any:
return [[item.get_value() for item in row] for row in self.value]
return [[item for item in row.get_value()] for row in self.value]
@deprecated(
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
category=RuntimeWarning
)
class AnyObject(BaseVariable):
class AnyVariable(BaseVariable):
type = 'any'
def valid_value(self, value: Any) -> Any:
@@ -142,10 +162,10 @@ class AnyObject(BaseVariable):
return str(self.value)
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
"""简化 ArrayObject 创建,不需要重复写类型"""
def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
"""简化 ArrayVariable 创建,不需要重复写类型"""
return ArrayObject(child_type, value)
return ArrayVariable(child_type, value)
def create_variable_instance(var_type: VariableType, value: Any) -> T:
@@ -168,7 +188,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return make_array(DictVariable, value)
case VariableType.ARRAY_FILE:
return make_array(FileVariable, value)
case VariableType.NESTED_ARRAY:
return NestedArrayVariable(value)
case VariableType.ANY:
return AnyObject(value)
return AnyVariable(value)
case _:
raise TypeError(f"Invalid type - {var_type}")

View File

@@ -35,6 +35,7 @@ from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .ontology_scene import OntologyScene
from .ontology_class import OntologyClass
from .implicit_emotions_storage_model import ImplicitEmotionsStorage
__all__ = [
"Tenants",
@@ -90,5 +91,6 @@ __all__ = [
"MemoryPerceptualModel",
"ModelBase",
"LoadBalanceStrategy",
"Skill"
"Skill",
"ImplicitEmotionsStorage"
]

View File

@@ -0,0 +1,45 @@
"""
Implicit Emotions Storage Model
数据库模型:存储用户的隐性记忆画像和情绪建议数据
替代原有的Redis缓存方式
"""
import uuid
from datetime import datetime
from sqlalchemy import Column, String, Text, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID, JSONB
from app.db import Base
class ImplicitEmotionsStorage(Base):
"""隐性记忆和情绪存储表"""
__tablename__ = "implicit_emotions_storage"
# 主键
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID")
# 用户标识unique=True会自动创建唯一索引
end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID")
# 隐性记忆画像数据JSON格式
implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据")
# 情绪建议数据JSON格式
emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据")
# 时间戳
created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间")
updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
# 数据生成时间(用于业务逻辑)
implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间")
emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间")
# 索引只为updated_at创建索引end_user_id的unique约束已自动创建索引
__table_args__ = (
Index('idx_updated_at', 'updated_at'),
)
def __repr__(self):
return f"<ImplicitEmotionsStorage(id={self.id}, end_user_id={self.end_user_id})>"

View File

@@ -2,7 +2,7 @@ import datetime
import uuid
from enum import StrEnum
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
@@ -78,6 +78,9 @@ class ModelConfig(BaseModel):
description = Column(String, comment="模型描述")
# 模型配置参数
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
config = Column(JSON, comment="模型配置参数")
# - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。
# - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。
@@ -118,6 +121,11 @@ class ModelApiKey(BaseModel):
api_key = Column(String, nullable=False, comment="API密钥")
api_base = Column(String, comment="API基础URL")
# 模型能力参数
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
# 配置参数
config = Column(JSON, comment="API Key特定配置")
@@ -155,6 +163,9 @@ class ModelBase(Base):
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作']")
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
comment="模型能力列表(如['vision', 'audio', 'video']")
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型使用特殊API调用")
# 关联关系
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")

View File

@@ -0,0 +1,169 @@
"""
Implicit Emotions Storage Repository
数据访问层:处理隐性记忆和情绪数据的数据库操作
事务由调用方控制,仓储层只使用 flush/refresh
"""
import logging
from datetime import datetime, date, timezone, timedelta
from typing import Optional, Generator
from sqlalchemy.orm import Session
from sqlalchemy import select, not_, exists
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from app.models.end_user_model import EndUser
logger = logging.getLogger(__name__)
class ImplicitEmotionsStorageRepository:
"""隐性记忆和情绪存储仓储类"""
def __init__(self, db: Session):
self.db = db
def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]:
"""根据终端用户ID获取存储记录"""
try:
stmt = select(ImplicitEmotionsStorage).where(
ImplicitEmotionsStorage.end_user_id == end_user_id
)
return self.db.execute(stmt).scalar_one_or_none()
except Exception as e:
logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}")
return None
def create(self, end_user_id: str) -> ImplicitEmotionsStorage:
"""创建新的存储记录(事务由调用方提交)"""
storage = ImplicitEmotionsStorage(
end_user_id=end_user_id,
created_at=datetime.utcnow(),
updated_at=datetime.utcnow()
)
self.db.add(storage)
self.db.flush()
self.db.refresh(storage)
logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}")
return storage
def update_implicit_profile(
self,
end_user_id: str,
profile_data: dict
) -> ImplicitEmotionsStorage:
"""更新隐性记忆画像数据(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage is None:
storage = self.create(end_user_id)
storage.implicit_profile = profile_data
storage.implicit_generated_at = datetime.utcnow()
storage.updated_at = datetime.utcnow()
self.db.flush()
self.db.refresh(storage)
logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}")
return storage
def update_emotion_suggestions(
self,
end_user_id: str,
suggestions_data: dict
) -> ImplicitEmotionsStorage:
"""更新情绪建议数据(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage is None:
storage = self.create(end_user_id)
storage.emotion_suggestions = suggestions_data
storage.emotion_generated_at = datetime.utcnow()
storage.updated_at = datetime.utcnow()
self.db.flush()
self.db.refresh(storage)
logger.info(f"更新情绪建议成功: end_user_id={end_user_id}")
return storage
def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取所有已存储数据的用户ID避免大数据量内存溢出
Args:
batch_size: 每批次加载的数量默认100
Yields:
用户ID字符串
"""
offset = 0
while True:
try:
stmt = (
select(ImplicitEmotionsStorage.end_user_id)
.order_by(ImplicitEmotionsStorage.end_user_id)
.limit(batch_size)
.offset(offset)
)
batch = self.db.execute(stmt).scalars().all()
if not batch:
break
yield from batch
offset += batch_size
except Exception as e:
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
break
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
查询逻辑end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。
没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。
end_users.idUUID转为字符串后与 implicit_emotions_storage.end_user_idString对比。
Args:
batch_size: 每批次加载的数量默认100
Yields:
用户ID字符串
"""
from sqlalchemy import cast, String as SAString
CST = timezone(timedelta(hours=8))
now_cst = datetime.now(CST)
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
tomorrow_start = today_start + timedelta(days=1)
offset = 0
while True:
try:
stmt = (
select(EndUser.id)
.where(
EndUser.created_at >= today_start,
EndUser.created_at < tomorrow_start,
not_(
exists(
select(ImplicitEmotionsStorage.end_user_id).where(
ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString)
)
)
)
)
.order_by(EndUser.id)
.limit(batch_size)
.offset(offset)
)
batch = self.db.execute(stmt).scalars().all()
if not batch:
break
yield from (str(uid) for uid in batch)
offset += batch_size
except Exception as e:
logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}")
break
def delete_by_end_user_id(self, end_user_id: str) -> bool:
"""删除用户的存储记录(事务由调用方提交)"""
storage = self.get_by_end_user_id(end_user_id)
if storage:
self.db.delete(storage)
self.db.flush()
logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}")
return True
return False

View File

@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
except Exception as e:
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
raise
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中permission_id='Memory'用户知识库的chunk_num总和
"""
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
try:
from sqlalchemy import func
result = db.query(func.sum(Knowledge.chunk_num)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id == "Memory"
).scalar()
total = result if result is not None else 0
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
return total
except Exception as e:
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
raise
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中排除用户知识库permission_id!='Memory')的数量
"""
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
try:
count = db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id != "Memory"
).count()
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
return count
except Exception as e:
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
raise

View File

@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
return datetime.datetime.now() > self.expires_at
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
@field_serializer('last_used_at')
@classmethod
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
created_at: datetime.datetime
@field_serializer('created_at')
@classmethod
def serialize_datetime(cls, v: datetime.datetime) -> int:
def serialize_datetime(self, v: datetime.datetime) -> int:
"""将datetime转换为时间戳"""
return datetime_to_timestamp(v)

View File

@@ -21,8 +21,14 @@ class FileType(StrEnum):
def trans(cls, value: str) -> 'FileType':
if value.startswith("image"):
return cls.IMAGE
# TODO: other file type support
raise RuntimeError("Unsupport file type")
elif value.startswith("document"):
return cls.DOCUMENT
elif value.startswith("audio"):
return cls.AUDIO
elif value.startswith("video"):
return cls.VIDEO
else:
raise RuntimeError("Unsupport file type")
class TransferMethod(str, Enum):
@@ -37,6 +43,12 @@ class FileInput(BaseModel):
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件IDlocal_file时必填")
url: Optional[str] = Field(None, description="远程URLremote_url时必填")
file_type: Optional[str] = Field(None, description="具体文件格式如image/jpg、audio/wav、document/docx、video/mp4")
def __init__(self, **data):
if "type" in data:
data['file_type'] = data['type']
super().__init__(**data)
@field_validator("type", mode="before")
@classmethod
@@ -433,7 +445,7 @@ class AppChatRequest(BaseModel):
user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
class DraftRunRequest(BaseModel):

View File

@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
class ChunkRetrieve(BaseModel):
query: str
kb_ids: list[uuid.UUID]
file_names_filter: list[str] | None = Field(None)
similarity_threshold: float | None = Field(None)
vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None)

View File

@@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel):
is_active: bool = Field(True, description="是否激活")
is_public: bool = Field(False, description="是否公开")
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
capability: List[str] = Field(default_factory=list, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型")
class ApiKeyCreateNested(BaseModel):
@@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel):
provider: Optional[str] = Field(None, description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
priority: str = Field("1", description="优先级", max_length=10)
@@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel):
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
is_active: Optional[bool] = Field(None, description="是否激活")
is_public: Optional[bool] = Field(None, description="是否公开")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
class ModelConfig(ModelConfigBase):
@@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel):
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
description: Optional[str] = Field(None, description="备注")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10)
@@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel):
provider: ModelProvider = Field(..., description="API Key提供商")
api_key: str = Field(..., description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: List[str] = Field(default_factory=list, description="模型能力列表")
is_omni: bool = Field(False, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
is_active: bool = Field(True, description="是否激活")
priority: str = Field("1", description="优先级", max_length=10)
@@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel):
provider: Optional[ModelProvider] = Field(None, description="API Key提供商")
api_key: Optional[str] = Field(None, description="API密钥", max_length=500)
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
is_active: Optional[bool] = Field(None, description="是否激活")
priority: Optional[str] = Field(None, description="优先级", max_length=10)
@@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel):
description: Optional[str] = Field(None, description="模型描述")
is_official: bool = Field(True, description="是否供应商官方模型")
tags: List[str] = Field(default_factory=list, description="模型标签")
capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video']")
is_omni: bool = Field(False, description="是否为Omni模型")
class ModelBaseUpdate(BaseModel):
@@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel):
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
is_official: Optional[bool] = Field(None, description="是否供应商官方模型")
tags: Optional[List[str]] = Field(None, description="模型标签")
capability: Optional[List[str]] = Field(None, description="模型能力列表")
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
class ModelBase(BaseModel):
@@ -298,6 +314,8 @@ class ModelBase(BaseModel):
is_official: bool
tags: List[str]
add_count: int
capability: List[str] = []
is_omni: bool = False
class ModelBaseQuery(BaseModel):

View File

@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
class MultiAgentConfigCreate(BaseModel):
"""创建多 Agent 配置"""
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
orchestration_mode: str = Field(
default="collaboration",
pattern="^(collaboration|supervisor)$",
description="协作模式collaboration协作| supervisor监督"
)
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则")
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
aggregation_strategy: str = Field(
default="merge",
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
class MultiAgentConfigUpdate(BaseModel):
"""更新多 Agent 配置"""
master_agent_id: Optional[uuid.UUID] = None
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
model_parameters: Optional[ModelParameters] = Field(
None,

View File

@@ -263,8 +263,8 @@ def create_agent_invocation_tool(
try:
# 9. 调用 Agent
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
result = await draft_service.run(
agent_config=agent_config,

View File

@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
AgentRunService
from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.workflow_service import WorkflowService
from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService
from app.services.workflow_service import WorkflowService
logger = get_business_logger()
@@ -39,6 +38,8 @@ class AppChatService:
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
self.agent_service = AgentRunService(db)
self.workflow_service = WorkflowService(db)
async def agnet_chat(
self,
@@ -55,12 +56,10 @@ class AppChatService:
files: Optional[List[FileInput]] = None # 新增:多模态文件
) -> Dict[str, Any]:
"""聊天(非流式)"""
start_time = time.time()
config_id = None
if variables is None:
variables = {}
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
model_config_id = config.default_model_config_id
@@ -79,74 +78,20 @@ class AppChatService:
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
memory_flag = False
if memory == True:
memory_config = config.memory
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
if memory:
memory_tools, memory_flag = self.agent_service.load_memory_config(
config.memory, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 获取模型参数
model_parameters = config.model_parameters
@@ -157,6 +102,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -180,7 +126,7 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
@@ -245,10 +191,9 @@ class AppChatService:
try:
start_time = time.time()
config_id = None
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
if variables is None:
variables = {}
variables = self.agent_service.prepare_variables(variables, config.variables)
# 获取模型配置ID
model_config_id = config.default_model_config_id
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
@@ -266,73 +211,22 @@ class AppChatService:
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(config, 'skills') and config.skills:
skills = config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
# 添加长期记忆工具
memory_flag = False
if memory:
memory_config = config.memory
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
memory_tools, memory_flag = self.agent_service.load_memory_config(
config.memory, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 获取模型参数
model_parameters = config.model_parameters
@@ -343,6 +237,7 @@ class AppChatService:
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -366,13 +261,10 @@ class AppChatService:
# 处理多模态文件
processed_files = None
if files:
multimodal_service = MultimodalService(self.db)
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent支持多模态
full_content = ""
total_tokens = 0
@@ -416,7 +308,7 @@ class AppChatService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
@@ -435,7 +327,7 @@ class AppChatService:
except Exception as e:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
async def multi_agent_chat(
self,
@@ -489,10 +381,10 @@ class AppChatService:
"mode": result.get("mode"),
"elapsed_time": result.get("elapsed_time"),
"usage": result.get("usage", {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
})
}
)
@@ -522,8 +414,6 @@ class AppChatService:
"""多 Agent 聊天(流式)"""
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
@@ -629,7 +519,6 @@ class AppChatService:
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
message=message,
variables=variables,
@@ -637,7 +526,7 @@ class AppChatService:
stream=True,
user_id=user_id
)
return await workflow_service.run(
return await self.workflow_service.run(
app_id=app_id,
payload=payload,
config=config,
@@ -664,7 +553,6 @@ class AppChatService:
) -> AsyncGenerator[dict, None]:
"""聊天(流式)"""
workflow_service = WorkflowService(self.db)
payload = DraftRunRequest(
message=message,
variables=variables,
@@ -673,7 +561,7 @@ class AppChatService:
user_id=user_id,
files=files
)
async for event in workflow_service.run_stream(
async for event in self.workflow_service.run_stream(
app_id=app_id,
payload=payload,
config=config,

View File

@@ -232,7 +232,7 @@ class AppService:
# 检查主 Agent 的模型配置
multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id
model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))
@@ -1791,372 +1791,6 @@ class AppService:
return shares
# ==================== 试运行功能 ====================
async def draft_run(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent使用当前草稿配置
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Returns:
Dict: 包含 AI 回复和元数据的字典
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用试运行服务
logger.debug(
"准备调用试运行服务",
extra={
"app_id": str(app_id),
"model": model_config.name,
"has_conversation_id": bool(conversation_id),
"has_variables": bool(variables)
}
)
draft_service = DraftRunService(self.db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
)
logger.debug(
"试运行服务返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
"has_message": "message" in result if isinstance(result, dict) else False,
"has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
}
)
logger.info(
"试运行完成",
extra={
"app_id": str(app_id),
"elapsed_time": result.get("elapsed_time"),
"model": model_config.name
}
)
return result
async def draft_run_stream(
self,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent流式返回
Args:
app_id: 应用ID
message: 用户消息
conversation_id: 会话ID用于多轮对话
user_id: 用户ID用于会话管理
variables: 自定义变量参数值
workspace_id: 工作空间ID用于权限验证
Yields:
str: SSE 格式的事件数据
Raises:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或配置缺失时
"""
from app.services.draft_run_service import DraftRunService
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
# 4. 调用流式试运行服务
draft_service = DraftRunService(self.db)
async for event in draft_service.run_stream(
agent_config=agent_cfg,
model_config=model_config,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables
):
yield event
# ==================== 多模型对比试运行 ====================
async def draft_run_compare(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
) -> Dict[str, Any]:
"""多模型对比试运行
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
parallel: 是否并行执行
timeout: 超时时间(秒)
Returns:
Dict: 对比结果
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(models),
"parallel": parallel
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的对比方法
draft_service = DraftRunService(self.db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
)
logger.info(
"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
"failed": result["failed_count"]
}
)
return result
async def draft_run_compare_stream(
self,
*,
app_id: uuid.UUID,
message: str,
models: List[app_schema.ModelCompareItem],
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None,
parallel: bool = True,
timeout: int = 60
):
"""多模型对比试运行(流式返回)
Args:
app_id: 应用ID
message: 用户消息
models: 要对比的模型列表
conversation_id: 会话ID
user_id: 用户ID
variables: 变量参数
workspace_id: 工作空间ID
timeout: 超时时间(秒)
Yields:
str: SSE 格式的事件数据
"""
from app.models import ModelConfig
from app.services.draft_run_service import DraftRunService
logger.info(
"多模型对比流式试运行",
extra={
"app_id": str(app_id),
"model_count": len(models)
}
)
# 1. 验证应用
app = self._get_app_or_404(app_id)
if app.type != "agent":
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
# 只读操作,允许访问共享应用
self._validate_app_accessible(app, workspace_id)
# 2. 获取 Agent 配置
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 准备所有模型配置
model_configs = []
for model_item in models:
model_config = self.db.get(ModelConfig, model_item.model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 合并参数agent配置参数 + 请求覆盖参数
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
}
model_configs.append({
"model_config": model_config,
"parameters": merged_parameters,
"label": model_item.label or model_config.name,
"model_config_id": model_item.model_config_id
})
# 4. 调用 DraftRunService 的流式对比方法
draft_service = DraftRunService(self.db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=message,
workspace_id=workspace_id,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
parallel=parallel,
timeout=timeout
):
yield event
logger.info(
"多模型对比流式完成",
extra={"app_id": str(app_id)}
)
# ==================== 向后兼容的函数接口 ====================
# 保留函数接口以兼容现有代码,但内部使用服务类
@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
return service.get_apps_by_ids(app_ids, workspace_id)
# ==================== 向后兼容的函数接口 ====================
async def draft_run(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""试运行 Agent向后兼容接口"""
service = AppService(db)
return await service.draft_run(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
)
async def draft_run_stream(
db: Session,
*,
app_id: uuid.UUID,
message: str,
conversation_id: Optional[str] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
workspace_id: Optional[uuid.UUID] = None
):
"""试运行 Agent 流式返回(向后兼容接口)"""
service = AppService(db)
async for event in service.draft_run_stream(
app_id=app_id,
message=message,
conversation_id=conversation_id,
user_id=user_id,
variables=variables,
workspace_id=workspace_id
):
yield event
# ==================== 依赖注入函数 ====================
def get_app_service(

View File

@@ -0,0 +1,101 @@
"""
音频转文本服务
支持的服务商:
- DashScope (阿里云通义千问)
- OpenAI Whisper
"""
import httpx
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AudioTranscriptionService:
"""音频转文本服务"""
@staticmethod
async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
"""
使用阿里云通义千问语音识别服务转换音频为文本
Args:
audio_url: 音频文件 URL
api_key: DashScope API Key
Returns:
str: 转录的文本
"""
try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
},
json={
"model": "paraformer-v2",
"input": {
"file_urls": [audio_url]
},
"parameters": {
"language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
}
}
)
response.raise_for_status()
result = response.json()
if result.get("output", {}).get("results"):
text = result["output"]["results"][0].get("transcription_text", "")
logger.info(f"音频转文本成功: {len(text)} 字符")
return text
return "[音频转文本失败]"
except Exception as e:
logger.error(f"DashScope 音频转文本失败: {e}")
return f"[音频转文本失败: {str(e)}]"
@staticmethod
async def transcribe_openai(audio_url: str, api_key: str) -> str:
"""
使用 OpenAI Whisper 转换音频为文本
Args:
audio_url: 音频文件 URL
api_key: OpenAI API Key
Returns:
str: 转录的文本
"""
try:
# 下载音频文件
async with httpx.AsyncClient(timeout=60.0) as client:
audio_response = await client.get(audio_url)
audio_response.raise_for_status()
audio_data = audio_response.content
# 调用 Whisper API
files = {"file": ("audio.mp3", audio_data, "audio/mpeg")}
data = {"model": "whisper-1"}
response = await client.post(
"https://api.openai.com/v1/audio/transcriptions",
headers={"Authorization": f"Bearer {api_key}"},
files=files,
data=data
)
response.raise_for_status()
result = response.json()
text = result.get("text", "")
logger.info(f"音频转文本成功: {len(text)} 字符")
return text
except Exception as e:
logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
return f"[音频转文本失败: {str(e)}]"

View File

@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
"provider": api_key_config.provider,
"api_key": api_key_config.api_key,
"api_base": api_key_config.api_base,
"is_omni": api_key_config.is_omni,
"model_parameters": config_data.get("model_parameters", {}),
"api_key_id": api_key_config.id
}
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
provider=agent_config["provider"],
api_key=agent_config["api_key"],
base_url=agent_config.get("api_base"),
is_omni=agent_config.get("is_omni", False),
extra_params=extra_params
)

View File

@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.agent.agent_middleware import AgentMiddleware
from app.core.agent.langchain_agent import LangChainAgent
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
@@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services import task_service
from app.services.conversation_service import ConversationService
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
@@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel):
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None):
def create_long_term_memory_tool(
memory_config: Dict[str, Any],
end_user_id: str,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
):
"""创建记忆工具,
@@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
memory_config: 记忆配置
end_user_id: 用户ID
storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
Returns:
长期记忆工具
@@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args:
query: 需要检索的问题或关键词
kb_config: 知识库配置
kb_ids: 知识库ID列表
user_id: 用户ID
Returns:
检索到的相关知识内容
@@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
return knowledge_retrieval_tool
class DraftRunService:
"""运行服务类"""
class AgentRunService:
"""Agent运行服务类"""
def __init__(self, db: Session):
"""初始化试运行服务
"""Agent运行服务
Args:
db: 数据库会话
"""
self.db = db
@staticmethod
def prepare_variables(
input_vars: dict | None,
variables_config: dict
) -> dict:
input_vars = input_vars or {}
for variable in variables_config:
if variable.get("required") and variable.get("name") not in input_vars:
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
return input_vars
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
"""加载工具配置"""
if not tools_config:
return []
tools = []
tool_service = ToolService(self.db)
if tools_config and isinstance(tools_config, list):
for tool_config in tools_config:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif tools_config and isinstance(tools_config, dict):
web_search_choice = tools_config.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search and web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
return tools
def load_skill_config(
self,
skills_config: dict | None,
message: str, tenant_id
) -> tuple[list, str]:
if not skills_config:
return [], ""
tools = []
skill_prompts = ""
skill_enable = skills_config.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills_config)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
skill_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
return tools, skill_prompts
def load_knowledge_retrieval_config(
self,
knowledge_retrieval_config: dict | None,
user_id
) -> list:
if not knowledge_retrieval_config:
return []
tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
return tools
def load_memory_config(
self,
memory_config: dict | None,
user_id,
storage_type,
user_rag_memory_id
) -> tuple[list, bool]:
"""加载长期记忆配置"""
if not memory_config:
return [], False
tools = []
if memory_config.get("enabled"):
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
return tools, bool(memory_config.get("enabled"))
async def run(
self,
*,
@@ -270,19 +403,21 @@ class DraftRunService:
conversation_id: 会话ID用于多轮对话
user_id: 用户ID
variables: 自定义变量参数值
storage_type: 存储类型(可选)
user_rag_memory_id: 用户RAG记忆ID可选
web_search: 是否启用网络搜索默认True
memory: 是否启用长期记忆默认True
sub_agent: 是否为子代理调用默认False
files: 多模态文件列表(可选)
Returns:
Dict: 包含 AI 回复和元数据的字典
"""
memory_flag = False
print('===========', storage_type)
print(user_id)
if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent
start_time = time.time()
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
try:
# 1. 获取 API Key 配置
@@ -302,112 +437,40 @@ class DraftRunService:
agent_config=agent_config
)
items_params = variables
if sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
system_prompt = render_prompt_message(
agent_config.system_prompt, # 修正拼写错误
agent_config.system_prompt,
PromptMessageRole.USER,
items_params
variables
)
# 3. 处理系统提示词(支持变量替换)
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
print('系统提示词:', system_prompt)
# 4. 准备工具列表
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
print("+" * 50)
print(f"agent_config:{agent_config}")
print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
# 添加长期记忆工具
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(
memory_config, user_id, storage_type, user_rag_memory_id
)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
@@ -415,6 +478,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -431,7 +495,7 @@ class DraftRunService:
# 6. 加载历史消息
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
@@ -442,7 +506,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -481,7 +545,7 @@ class DraftRunService:
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
# 9. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -556,16 +620,21 @@ class DraftRunService:
Yields:
str: SSE 格式的事件数据
"""
memory_flag = False
if variables == None: variables = {}
from app.core.agent.langchain_agent import LangChainAgent
tools_config: dict | list | None = agent_config.tools
skills_config: dict | None = agent_config.skills
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
memory_config: dict | None = agent_config.memory
start_time = time.time()
try:
# 1. 获取 API Key 配置
api_key_config = await self._get_api_key(model_config.id)
if not sub_agent:
variables = self.prepare_variables(variables, agent_config.variables)
else:
# FIXME: subagent input valid
variables = variables or {}
# 2. 合并模型参数
effective_params = ModelParameterMerger.get_effective_parameters(
@@ -587,95 +656,22 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
tool_service = ToolService(self.db)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
for tool_config in agent_config.tools:
# print("+"*50)
# print(f"agent_config:{agent_config}")
# print(f"tool_config:{tool_config}")
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
tools.append(search_tool)
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
tools.extend(skill_tools)
if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 加载技能关联的工具
if hasattr(agent_config, 'skills') and agent_config.skills:
skills = agent_config.skills
skill_enable = skills.get("enabled", False)
if skill_enable:
middleware = AgentMiddleware(skills=skills)
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
tools.extend(skill_tools)
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
# 应用动态过滤
if skill_configs:
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
tool_to_skill_map)
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
active_prompts = AgentMiddleware.get_active_prompts(
activated_skill_ids, skill_configs
)
system_prompt = f"{system_prompt}\n\n{active_prompts}"
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
kb_config = agent_config.knowledge_retrieval
knowledge_bases = kb_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids:
# 创建知识库检索工具
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
tools.append(kb_tool)
logger.debug(
"已添加知识库检索工具",
extra={
"kb_ids": kb_ids,
"tool_count": len(tools)
}
)
# 添加长期记忆工具
memory_flag = False
if memory:
if agent_config.memory and agent_config.memory.get("enabled"):
memory_flag = True
memory_config = agent_config.memory
if user_id:
# 创建长期记忆工具
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.append(memory_tool)
logger.debug(
"已添加长期记忆工具",
extra={
"user_id": user_id,
"tool_count": len(tools)
}
)
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
user_rag_memory_id)
tools.extend(memory_tools)
# 4. 创建 LangChain Agent
agent = LangChainAgent(
@@ -683,6 +679,7 @@ class DraftRunService:
api_key=api_key_config["api_key"],
provider=api_key_config.get("provider", "openai"),
api_base=api_key_config.get("api_base"),
is_omni=api_key_config.get("is_omni", False),
temperature=effective_params.get("temperature", 0.7),
max_tokens=effective_params.get("max_tokens", 2000),
system_prompt=system_prompt,
@@ -700,10 +697,10 @@ class DraftRunService:
# 6. 加载历史消息
history = []
if agent_config.memory and agent_config.memory.get("enabled"):
if memory_config and memory_config.get("enabled"):
history = await self._load_conversation_history(
conversation_id=conversation_id,
max_history=agent_config.memory.get("max_history", 10)
max_history=memory_config.get("max_history", 10)
)
# 6. 处理多模态文件
@@ -711,7 +708,7 @@ class DraftRunService:
if files:
# 获取 provider 信息
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, provider=provider)
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
processed_files = await multimodal_service.process_files(files)
logger.info(f"处理了 {len(processed_files)} 个文件provider={provider}")
@@ -761,7 +758,7 @@ class DraftRunService:
})
# 10. 保存会话消息
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
if not sub_agent and memory_config and memory_config.get("enabled"):
await self._save_conversation_message(
conversation_id=conversation_id,
user_message=message,
@@ -809,7 +806,7 @@ class DraftRunService:
"""
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
"""获取模型的 API Key
Args:
@@ -846,7 +843,8 @@ class DraftRunService:
"provider": api_key.provider,
"api_key": api_key.api_key,
"api_base": api_key.api_base,
"api_key_id": api_key.id
"api_key_id": api_key.id,
"is_omni": api_key.is_omni
}
async def _ensure_conversation(
@@ -966,7 +964,6 @@ class DraftRunService:
List[Dict]: 历史消息列表
"""
try:
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history(
@@ -1486,6 +1483,15 @@ class DraftRunService:
"conversation_id": returned_conversation_id,
"content": chunk
}))
if event_type == "error" and event_data:
await event_queue.put(self._format_sse_event("model_error", {
"model_index": idx,
"model_config_id": model_config_id,
"label": model_label,
"conversation_id": returned_conversation_id,
"error": event_data.get("error", "未知错误")
}))
except Exception as e:
logger.warning(f"解析流式事件失败: {e}")
finally:
@@ -1670,41 +1676,3 @@ class DraftRunService:
"total_time": sum(r.get("elapsed_time", 0) for r in results)
}
)
async def draft_run(
db: Session,
*,
agent_config: AgentConfig,
model_config: ModelConfig,
message: str,
user_id: Optional[str] = None,
kb_ids: Optional[List[str]] = None,
similarity_threshold: float = 0.7,
top_k: int = 3
) -> Dict[str, Any]:
"""试运行 Agent便捷函数
Args:
db: 数据库会话
agent_config: Agent 配置
model_config: 模型配置
message: 用户消息
user_id: 用户ID
kb_ids: 知识库ID列表
similarity_threshold: 相似度阈值
top_k: 检索返回的文档数量
Returns:
Dict: 包含 AI 回复和元数据的字典
"""
service = DraftRunService(db)
return await service.run(
agent_config=agent_config,
model_config=model_config,
message=message,
user_id=user_id,
kb_ids=kb_ids,
similarity_threshold=similarity_threshold,
top_k=top_k
)

View File

@@ -843,32 +843,33 @@ class EmotionAnalyticsService:
end_user_id: str,
db: Session,
) -> Optional[Dict[str, Any]]:
""" Redis 缓存获取个性化情绪建议
"""数据库获取个性化情绪建议
Args:
end_user_id: 宿主ID用户组ID
db: 数据库会话(保留参数以保持接口兼容性)
db: 数据库会话
Returns:
Dict: 存的建议数据,如果不存在或已过期返回 None
Dict: 存的建议数据,如果不存在返回 None
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
logger.info(f"尝试从数据库获取情绪建议: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
# 从数据库获取存储记录
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
if storage is None or storage.emotion_suggestions is None:
logger.info(f"用户 {end_user_id} 的建议数据不存在")
return None
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
return cached_data
logger.info(f"成功从数据库获取建议: user={end_user_id}")
return storage.emotion_suggestions
except Exception as e:
logger.error(f" Redis 缓存获取建议失败: {str(e)}", exc_info=True)
logger.error(f"数据库获取建议失败: {str(e)}", exc_info=True)
return None
async def save_suggestions_cache(
@@ -876,36 +877,27 @@ class EmotionAnalyticsService:
end_user_id: str,
suggestions_data: Dict[str, Any],
db: Session,
expires_hours: int = 24
expires_hours: int = 24 # 参数保留以保持接口兼容性
) -> None:
"""保存建议到 Redis 缓存
"""保存建议到数据库
Args:
end_user_id: 宿主ID用户组ID
suggestions_data: 建议数据
db: 数据库会话(保留参数以保持接口兼容性)
expires_hours: 过期时间小时默认24小时
db: 数据库会话
expires_hours: 保留参数(兼容性)
"""
try:
from app.cache.memory.emotion_memory import EmotionMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
logger.info(f"保存建议到数据库: user={end_user_id}")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_emotion_suggestions(end_user_id, suggestions_data)
db.commit()
# 保存到 Redis
success = await EmotionMemoryCache.set_emotion_suggestions(
user_id=end_user_id,
suggestions_data=suggestions_data,
expire=expire_seconds
)
if success:
logger.info(f"建议缓存保存成功: user={end_user_id}")
else:
logger.warning(f"建议缓存保存失败: user={end_user_id}")
logger.info(f"建议保存成功: user={end_user_id}")
except Exception as e:
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程
db.rollback()
logger.error(f"保存建议失败: {str(e)}", exc_info=True)

View File

@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
provider=model_api_key.provider,
api_key=model_api_key.api_key,
base_url=model_api_key.api_base,
is_omni=model_api_key.is_omni,
extra_params={
"temperature": 0.7,
"max_tokens": 2000,

View File

@@ -422,32 +422,33 @@ class ImplicitMemoryService:
end_user_id: str,
db: Session
) -> Optional[dict]:
""" Redis 缓存获取完整用户画像
"""数据库获取完整用户画像
Args:
end_user_id: 终端用户ID
db: 数据库会话(保留参数以保持接口兼容性)
db: 数据库会话
Returns:
Dict: 存的画像数据,如果不存在或已过期返回 None
Dict: 存的画像数据,如果不存在返回 None
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}")
logger.info(f"尝试从数据库获取用户画像: user={end_user_id}")
# 从 Redis 获取缓存
cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id)
# 从数据库获取存储记录
repo = ImplicitEmotionsStorageRepository(db)
storage = repo.get_by_end_user_id(end_user_id)
if cached_data is None:
logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
if storage is None or storage.implicit_profile is None:
logger.info(f"用户 {end_user_id} 的画像数据不存在")
return None
logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}")
return cached_data
logger.info(f"成功从数据库获取用户画像: user={end_user_id}")
return storage.implicit_profile
except Exception as e:
logger.error(f" Redis 缓存获取用户画像失败: {str(e)}", exc_info=True)
logger.error(f"数据库获取用户画像失败: {str(e)}", exc_info=True)
return None
async def save_profile_cache(
@@ -455,36 +456,27 @@ class ImplicitMemoryService:
end_user_id: str,
profile_data: dict,
db: Session,
expires_hours: int = 168 # 默认7天
expires_hours: int = 168 # 参数保留以保持接口兼容性
) -> None:
"""保存用户画像到 Redis 缓存
"""保存用户画像到数据库
Args:
end_user_id: 终端用户ID
profile_data: 画像数据
db: 数据库会话(保留参数以保持接口兼容性)
expires_hours: 过期时间小时默认168小时7天
db: 数据库会话
expires_hours: 保留参数(兼容性
"""
try:
from app.cache.memory.implicit_memory import ImplicitMemoryCache
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
logger.info(f"保存用户画像到数据库: user={end_user_id}")
# 计算过期时间(秒)
expire_seconds = expires_hours * 3600
repo = ImplicitEmotionsStorageRepository(db)
repo.update_implicit_profile(end_user_id, profile_data)
db.commit()
# 保存到 Redis
success = await ImplicitMemoryCache.set_user_profile(
user_id=end_user_id,
profile_data=profile_data,
expire=expire_seconds
)
if success:
logger.info(f"用户画像缓存保存成功: user={end_user_id}")
else:
logger.warning(f"用户画像缓存保存失败: user={end_user_id}")
logger.info(f"用户画像保存成功: user={end_user_id}")
except Exception as e:
logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True)
# 不抛出异常,缓存失败不应影响主流程
db.rollback()
logger.error(f"保存用户画像失败: {str(e)}", exc_info=True)

View File

@@ -9,6 +9,8 @@ load_dotenv()
# 读取web_search环境变量
web_search_value = os.getenv('web_search')
def Search(query):
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
api_key = web_search_value
@@ -18,23 +20,24 @@ def Search(query):
"role": "user",
"content": query
}
], #搜索输入
"edition":"standard", #搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", #使用的搜索引擎版本
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
], # 搜索输入
"edition": "standard", # 搜索版本。默认为standard。可选值standard完整版本。lite标准版本对召回规模和精排条数简化后的版本时延表现更好效果略弱于完整版。
"search_source": "baidu_search_v2", # 使用的搜索引擎版本
"resource_type_filter": [{"type": "web", "top_k": 20}],
# 支持设置网页、视频、图片、阿拉丁搜索模态网页top_k最大取值为50视频top_k最大为10图片top_k最大为30阿拉丁top_k最大为5
"search_filter": {
"range": {
"page_time": {
"gte": "now-1w/d", #时间查询参数,大于或等于
"lt": "now/d", #时间查询参数,小于
"gt": "", #时间查询参数,大于
"lte": "" #时间查询参数,小于或等于
"gte": "now-1w/d", # 时间查询参数,大于或等于
"lt": "now/d", # 时间查询参数,小于
"gt": "", # 时间查询参数,大于
"lte": "" # 时间查询参数,小于或等于
}
}
},
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
"search_recency_filter":"week", #根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content":True #是否输出网页完整原文
"block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表
"search_recency_filter": "week", # 根据网页发布时间进行筛选可填值为week,month,semiyear,year
"enable_full_content": True # 是否输出网页完整原文
}, ensure_ascii=False)
headers = {
'Content-Type': 'application/json',
@@ -42,10 +45,10 @@ def Search(query):
}
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
content=[]
content = []
for i in response['references']:
title=i['title']
snippet=i['snippet']
content.append(title+';'+snippet)
content=''.join(content)
return content
title = i['title']
snippet = i['snippet']
content.append(title + ';' + snippet)
content = ''.join(content)
return content

View File

@@ -414,6 +414,7 @@ class LLMRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.3,
max_tokens=500
)

View File

@@ -392,6 +392,7 @@ class MasterAgentRouter:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
extra_params = extra_params
)

View File

@@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import (
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
@@ -816,11 +816,10 @@ class MemoryAgentService:
"""
统计知识库类型分布,包含:
1. PostgreSQL 中的知识库类型General, Web, Third-party, Folder根据 workspace_id 过滤)
2. Neo4j 中的 Memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
3. total: 所有类型的总和
2. total: 所有类型的总和
参数:
- end_user_id: 用户组ID可选未提供时 Memory 统计为 0
- end_user_id: 用户组ID可选保留参数以保持接口兼容性
- only_active: 是否仅统计有效记录
- current_workspace_id: 当前工作空间ID可选未提供时知识库统计为 0
- db: 数据库会话
@@ -831,7 +830,6 @@ class MemoryAgentService:
"Web": count,
"Third-party": count,
"Folder": count,
"Memory": chunk_count,
"total": sum_of_all
}
"""
@@ -878,51 +876,8 @@ class MemoryAgentService:
logger.error(f"知识库类型统计失败: {e}")
raise Exception(f"知识库类型统计失败: {e}")
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
try:
if current_workspace_id:
# 获取当前空间下的所有宿主
from app.repositories import app_repository, end_user_repository
from app.schemas.app_schema import App as AppSchema
from app.schemas.end_user_schema import EndUser as EndUserSchema
# 查询应用并转换为 Pydantic 模型
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
apps = [AppSchema.model_validate(h) for h in apps_orm]
app_ids = [app.id for app in apps]
# 获取所有宿主
end_users = []
for app_id in app_ids:
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend(h for h in end_user_orm_list)
# 统计所有宿主的 Chunk 总数
total_chunks = 0
for end_user in end_users:
end_user_id_str = str(end_user.id)
memory_query = """
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
"""
neo4j_result = await _neo4j_connector.execute_query(
memory_query,
end_user_id=end_user_id_str,
)
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
total_chunks += chunk_count
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
result["Memory"] = total_chunks
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
else:
# 没有 workspace_id 时,返回 0
result["Memory"] = 0
logger.info("未提供 workspace_idmemory 统计为 0")
except Exception as e:
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
# 如果 Neo4j 查询失败memory 设为 0
result["Memory"] = 0
# 2. 统计 Neo4j 中的 memory 总量已移除
# memory 字段不再返回
# 3. 计算知识库类型总和(不包括 memory
result["total"] = (
@@ -935,36 +890,36 @@ class MemoryAgentService:
return result
async def get_hot_memory_tags_by_user(
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 20
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的热门记忆标签
获取指定用户的兴趣分布标签
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
参数:
- end_user_id: 用户ID可选对应Neo4j中的end_user_id字段
- end_user_id: 用户ID必填)
- limit: 返回标签数量限制
- language: 输出语言("zh" 中文, "en" 英文)
返回格式:
[
{"name": "标签", "frequency": 频次},
{"name": "兴趣活动", "frequency": 频次},
...
]
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
"""
try:
# by_user=False 表示按 end_user_id 查询在Neo4j中end_user_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload = []
for tag, freq in tags:
payload.append({"name": tag, "frequency": freq})
return payload
tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
return [{"name": tag, "frequency": freq} for tag, freq in tags]
except Exception as e:
logger.error(f"热门记忆标签查询失败: {e}")
raise Exception(f"热门记忆标签查询失败: {e}")
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(

View File

@@ -140,9 +140,11 @@ class MemoryAPIService:
try:
# Delegate to MemoryAgentService
# Convert string message to list[dict] format expected by MemoryAgentService
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
result = await MemoryAgentService().write_memory(
end_user_id=end_user_id,
messages=message,
messages=messages,
config_id=config_id,
db=self.db,
storage_type=storage_type,
@@ -151,9 +153,18 @@ class MemoryAPIService:
logger.info(f"Memory write successful for end_user: {end_user_id}")
# result may be a string "success" or a dict with a "status" key
# Preserve the full dict so callers don't silently lose extra fields
# (e.g. error codes, metadata) returned by MemoryAgentService.
if isinstance(result, dict):
return {
**result,
"status": result.get("status", "unknown"),
"end_user_id": end_user_id,
}
return {
"status": "success" if result == "success" else result,
"end_user_id": end_user_id
"status": result if isinstance(result, str) else "success",
"end_user_id": end_user_id,
}
except ConfigurationError as e:

View File

@@ -390,19 +390,59 @@ def get_rag_total_kb(
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库permission_id!='Memory'的数量
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
return total_kb
except Exception as e:
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_user_kb_total_chunk(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id从documents表统计所有用户知识库的chunk总数。
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from app.models.end_user_model import EndUser
from app.models.app_model import App
from sqlalchemy import func
# 通过 App 关联取该 workspace 下所有 end_user_id
end_user_ids = [
str(eid) for (eid,) in db.query(EndUser.id)
.join(App, EndUser.app_id == App.id)
.filter(App.workspace_id == workspace_id)
.all()
]
if not end_user_ids:
return 0
file_names = [f"{uid}.txt" for uid in end_user_ids]
result = db.query(func.sum(Document.chunk_num)).filter(
Document.file_name.in_(file_names)
).scalar()
total_chunk = int(result or 0)
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
return total_chunk
except Exception as e:
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_current_user_total_chunk(
end_user_id: str,
db: Session,

View File

@@ -1,22 +1,37 @@
from typing import Dict, List
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.repositories.memory_short_repository import ShortTermMemoryRepository
from app.repositories.memory_short_repository import (
LongTermMemoryRepository,
ShortTermMemoryRepository,
)
api_logger = get_api_logger()
db=next(get_db())
class ShortService:
def __init__(self, end_user_id):
def __init__(self, end_user_id: str, db: Session) -> None:
"""Service for short-term memory queries.
Args:
end_user_id: The end user identifier to query memories for.
db: SQLAlchemy database session (caller-managed lifecycle).
"""
self.short_repo = ShortTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_short_databasets(self):
def get_short_databasets(self) -> List[Dict]:
"""Retrieve the latest short-term memory entries for the user.
Returns:
List[Dict]: List of memory dicts with retrieval, message, and answer keys.
"""
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
short_result = []
for memory in short_memories:
deep_expanded = {} # Create a new dictionary for each memory
deep_expanded = {}
messages = memory.messages
aimessages = memory.aimessages
retrieved_content = memory.retrieved_content or []
@@ -27,23 +42,41 @@ class ShortService:
for item in retrieved_content:
if isinstance(item, dict):
for key, values in item.items():
retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"})
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
deep_expanded['retrieval'] = retrieval_source
deep_expanded['message'] = messages # 修正拼写错误
deep_expanded['message'] = messages
deep_expanded['answer'] = aimessages
short_result.append(deep_expanded)
return short_result
def get_short_count(self):
def get_short_count(self) -> int:
"""Count total short-term memory entries for the user.
Returns:
int: Number of short-term memory records.
"""
short_count = self.short_repo.count_by_user_id(self.end_user_id)
return short_count
class LongService:
def __init__(self, end_user_id):
def __init__(self, end_user_id: str, db: Session) -> None:
"""Service for long-term memory queries.
Args:
end_user_id: The end user identifier to query memories for.
db: SQLAlchemy database session (caller-managed lifecycle).
"""
self.long_repo = LongTermMemoryRepository(db)
self.end_user_id = end_user_id
def get_long_databasets(self):
# 获取长期记忆数据
def get_long_databasets(self) -> List[Dict]:
"""Retrieve long-term memory retrieval data for the user.
Returns:
List[Dict]: List of dicts with query and retrieval keys.
"""
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
long_result = []

View File

@@ -90,7 +90,8 @@ class ModelConfigService:
api_key: str,
api_base: Optional[str] = None,
model_type: str = "llm",
test_message: str = "Hello"
test_message: str = "Hello",
is_omni: bool = False
) -> Dict[str, Any]:
"""验证模型配置是否有效
@@ -102,6 +103,7 @@ class ModelConfigService:
api_base: API基础URL
model_type: 模型类型 (llm/chat/embedding/rerank)
test_message: 测试消息
is_omni: 是否为Omni模型
Returns:
Dict: 验证结果
@@ -114,14 +116,27 @@ class ModelConfigService:
try:
start_time = time.time()
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# dashscope 的 omni 模型需要使用 compatible-mode
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
if not api_base:
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
model_config = RedBearModelConfig(
model_name=model_name,
provider=ModelProvider.OPENAI,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
else:
model_config = RedBearModelConfig(
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base,
temperature=0.7,
max_tokens=100
)
# 根据模型类型选择不同的验证方式
model_type_lower = model_type.lower()
@@ -257,8 +272,9 @@ class ModelConfigService:
provider=model_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_data.type, # 传递模型类型
test_message="Hello"
model_type=model_data.type,
test_message="Hello",
is_omni=model_data.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -279,6 +295,9 @@ class ModelConfigService:
for api_key_data in api_key_datas:
api_key_data.model_name = model_data.name
api_key_data.provider = model_data.provider
# 同步capability和is_omni
api_key_data.capability = model_data.capability
api_key_data.is_omni = model_data.is_omni
api_key_create_schema = ModelApiKeyCreate(
model_config_ids=[model.id],
**api_key_data.model_dump()
@@ -497,6 +516,8 @@ class ModelApiKeyService:
existing_key.config = data.config
existing_key.priority = data.priority
existing_key.model_name = model_name
existing_key.capability = data.capability
existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -513,7 +534,8 @@ class ModelApiKeyService:
api_key=data.api_key,
api_base=data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=data.is_omni
)
if not validation_result["valid"]:
# 记录验证失败的模型,但不抛出异常
@@ -528,6 +550,8 @@ class ModelApiKeyService:
provider=data.provider,
api_key=data.api_key,
api_base=data.api_base,
capability=data.capability if data.capability is not None else model_config.capability,
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
config=data.config,
is_active=data.is_active,
priority=data.priority
@@ -572,6 +596,8 @@ class ModelApiKeyService:
existing_key.config = api_key_data.config
existing_key.priority = api_key_data.priority
existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置
if model_config not in existing_key.model_configs:
@@ -589,7 +615,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -620,7 +647,8 @@ class ModelApiKeyService:
api_key=api_key_data.api_key or existing_api_key.api_key,
api_base=api_key_data.api_base or existing_api_key.api_base,
model_type=model_config.type,
test_message="Hello"
test_message="Hello",
is_omni=model_config.is_omni
)
if not validation_result["valid"]:
raise BusinessException(
@@ -755,6 +783,8 @@ class ModelBaseService:
"type": model_base.type,
"logo": model_base.logo,
"description": model_base.description,
"capability": model_base.capability,
"is_omni": model_base.is_omni,
"is_composite": False
}
model_config = ModelConfigRepository.create(db, model_config_data)

View File

@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
user_id: 用户 ID
variables: 变量参数
use_llm_routing: 是否使用 LLM 路由
web_search: 是否启用网络搜索
memory: 是否启用记忆功能
storage_type: 存储类型
user_rag_memory_id: 用户 RAG 记忆 ID
Yields:
SSE 格式的事件流
"""
import json
start_time = time.time()
@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
except Exception as e:
logger.error(
"多 Agent 任务执行失败(流式)",
extra={"error": str(e), "mode": self._normalized_mode}
extra={"error": str(e), "mode": self._normalized_mode},
exc_info=True
)
# 发送错误事件
yield self._format_sse_event("error", {
@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
Yields:
SSE 格式的事件流
"""
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
# 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
)
# 流式执行 Agent
draft_service = DraftRunService(self.db)
draft_service = AgentRunService(self.db)
async for event in draft_service.run_stream(
agent_config=agent_config,
model_config=model_config,
@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
Returns:
执行结果
"""
from app.services.draft_run_service import DraftRunService
from app.services.draft_run_service import AgentRunService
# 获取模型配置
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
)
# 执行 Agent
draft_service = DraftRunService(self.db)
draft_service = AgentRunService(self.db)
result = await draft_service.run(
agent_config=agent_config,
model_config=model_config,
@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
self.memory = config_data.get("memory")
self.variables = config_data.get("variables", [])
self.tools = config_data.get("tools", {})
self.skills = config_data.get("skills", {})
self.default_model_config_id = release.default_model_config_id
return AgentConfigProxy(release, app, config_data)
@@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7, # 整合任务使用中等温度
max_tokens=2000
)
@@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator:
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
is_omni=api_key_config.is_omni,
temperature=0.7,
max_tokens=2000,
extra_params={"streaming": True} # 启用流式输出

View File

@@ -267,7 +267,7 @@ class MultiAgentService:
# 2. 验证模型配置(如果提供了)
if data.default_model_config_id:
model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
if not model_api_key:
raise ResourceNotFoundException("模型配置", str(data.default_model_config_id))

View File

@@ -9,47 +9,100 @@
- OpenAI: 支持 URL 和 base64 格式
"""
import uuid
from typing import List, Dict, Any, Optional, Protocol
import httpx
import base64
from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod
from sqlalchemy.orm import Session
from docx import Document
import io
import PyPDF2
from app.core.logging_config import get_business_logger
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.schemas.app_schema import FileInput, FileType, TransferMethod
from app.models.generic_file_model import GenericFile
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
from app.services.audio_transcription_service import AudioTranscriptionService
logger = get_business_logger()
class ImageFormatStrategy(Protocol):
"""图片格式策略接口"""
class MultimodalFormatStrategy(ABC):
"""多模态格式策略基类"""
@abstractmethod
async def format_image(self, url: str) -> Dict[str, Any]:
"""格式化图片"""
pass
@abstractmethod
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""格式化文档"""
pass
@abstractmethod
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
"""格式化音频"""
pass
@abstractmethod
async def format_video(self, url: str) -> Dict[str, Any]:
"""格式化视频"""
pass
class DashScopeFormatStrategy(MultimodalFormatStrategy):
"""通义千问策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""将图片 URL 转换为特定 provider 的格式"""
...
class DashScopeImageStrategy:
"""通义千问图片格式策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""通义千问格式: {"type": "image", "image": "url"}"""
"""通义千问图片格式:{"type": "image", "image": "url"}"""
return {
"type": "image",
"image": url
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""通义千问文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
class BedrockImageStrategy:
"""Bedrock/Anthropic 图片格式策略"""
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
通义千问音频格式
- 原生支持: qwen-audio 系列
- 其他模型: 需要转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# 通义千问音频格式:{"type": "audio", "audio": "url"}
return {
"type": "audio",
"audio": url
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""通义千问视频格式qwen-vl 系列原生支持)"""
return {
"type": "video",
"video": url
}
class BedrockFormatStrategy(MultimodalFormatStrategy):
"""Bedrock/Anthropic 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""
Bedrock/Anthropic 格式: base64 编码
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
"""
import httpx
import base64
from mimetypes import guess_type
logger.info(f"下载并编码图片: {url}")
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
# Bedrock 文档需要 base64 编码
text_bytes = text.encode('utf-8')
base64_text = base64.b64encode(text_bytes).decode('utf-8')
class OpenAIImageStrategy:
"""OpenAI 图片格式策略"""
return {
"type": "document",
"source": {
"type": "base64",
"media_type": "text/plain",
"data": base64_text
}
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
Bedrock/Anthropic 音频格式
不支持原生音频,必须转录为文本
"""
if transcription:
return {
"type": "text",
"text": f"[音频转录]\n{transcription}"
}
return {
"type": "text",
"text": "[音频文件Bedrock 不支持原生音频,请启用音频转文本功能]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""Bedrock/Anthropic 视频格式"""
return {
"type": "text",
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
}
class OpenAIFormatStrategy(MultimodalFormatStrategy):
"""OpenAI 策略"""
async def format_image(self, url: str) -> Dict[str, Any]:
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
}
}
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
"""OpenAI 文档格式"""
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
"""
OpenAI 音频格式
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
- 其他模型使用转录文本
"""
if transcription:
return {
"type": "text",
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
}
# OpenAI 音频需要 base64 编码
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url)
response.raise_for_status()
audio_data = response.content
base64_audio = base64.b64encode(audio_data).decode('utf-8')
# 1. 优先从 file_type (MIME) 取扩展名
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
# 2. 从响应头 content-type 取
if not file_ext:
ct = response.headers.get("content-type", "")
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
# 3. 从 URL 路径取扩展名
if not file_ext:
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
# 4. 默认 wav
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
file_ext = "wav" if not file_ext else file_ext
return {
"type": "input_audio",
"input_audio": {
"data": f"data:;base64,{base64_audio}",
"format": file_ext
}
}
except Exception as e:
logger.error(f"下载音频失败: {e}")
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def format_video(self, url: str) -> Dict[str, Any]:
"""OpenAI 视频格式"""
return {
"type": "video_url",
"video_url": {
"url": url
}
}
# Provider 到策略的映射
PROVIDER_STRATEGIES = {
"dashscope": DashScopeImageStrategy,
"bedrock": BedrockImageStrategy,
"anthropic": BedrockImageStrategy,
"openai": OpenAIImageStrategy,
"dashscope": DashScopeFormatStrategy,
"bedrock": BedrockFormatStrategy,
"anthropic": BedrockFormatStrategy,
"openai": OpenAIFormatStrategy,
}
class MultimodalService:
"""多模态文件处理服务"""
def __init__(self, db: Session, provider: str = "dashscope"):
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
"""
初始化多模态服务
Args:
db: 数据库会话
provider: 模型提供商dashscope, bedrock, anthropic 等)
provider: 模型提供商dashscope, bedrock, anthropic, openai 等)
api_key: API 密钥(用于音频转文本)
enable_audio_transcription: 是否启用音频转文本
is_omni: 是否为 Omni 模型dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
"""
self.db = db
self.provider = provider.lower()
self.api_key = api_key
self.enable_audio_transcription = enable_audio_transcription
self.is_omni = is_omni
async def process_files(
self,
@@ -137,20 +295,32 @@ class MultimodalService:
if not files:
return []
# 获取对应的策略
# dashscope 的 omni 模型使用 OpenAI 兼容格式
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
strategy_class = DashScopeFormatStrategy
strategy = strategy_class()
result = []
for idx, file in enumerate(files):
try:
if file.type == FileType.IMAGE:
content = await self._process_image(file)
content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
content = await self._process_document(file)
content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO:
content = await self._process_audio(file)
content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO:
content = await self._process_video(file)
content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"不支持的文件类型: {file.type}")
@@ -172,55 +342,29 @@ class MultimodalService:
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件provider={self.provider}")
return result
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理图片文件
Args:
file: 图片文件输入
strategy: 格式化策略
Returns:
Dict: 根据 provider 返回不同格式
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
- 通义千问: {"type": "image", "image": "url"}
Dict: 根据 provider 返回不同格式的图片内容
"""
url = await self.get_file_url(file)
logger.debug(f"处理图片: {url}, provider={self.provider}")
# 根据 provider 返回不同格式
if self.provider in ["bedrock", "anthropic"]:
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
try:
logger.info(f"开始下载并编码图片: {url}")
base64_data, media_type = await self._download_and_encode_image(url)
result = {
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": base64_data[:100] + "..." # 只记录前100个字符
}
}
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
# 返回完整数据
result["source"]["data"] = base64_data
return result
except Exception as e:
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
# 返回错误提示
return {
"type": "text",
"text": f"[图片加载失败: {str(e)}]"
}
else:
# 通义千问等其他格式支持 URL
try:
url = await self.get_file_url(file)
return await strategy.format_image(url)
except Exception as e:
logger.error(f"处理图片失败: {e}", exc_info=True)
return {
"type": "image",
"image": url
"type": "text",
"text": f"[图片处理失败: {str(e)}]"
}
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
@staticmethod
async def _download_and_encode_image(url: str) -> tuple[str, str]:
"""
下载图片并转换为 base64
@@ -230,8 +374,6 @@ class MultimodalService:
Returns:
tuple: (base64_data, media_type)
"""
import httpx
import base64
from mimetypes import guess_type
# 下载图片
@@ -258,15 +400,16 @@ class MultimodalService:
return base64_data, media_type
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理文档文件PDF、Word 等)
Args:
file: 文档文件输入
strategy: 格式化策略
Returns:
Dict: text 格式的内容(包含提取的文本)
Dict: 根据 provider 返回不同格式的文档内容
"""
if file.transfer_method == TransferMethod.REMOTE_URL:
# 远程文档暂不支持提取
@@ -277,48 +420,68 @@ class MultimodalService:
else:
# 本地文件,提取文本内容
text = await self._extract_document_text(file.upload_file_id)
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file.upload_file_id
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
file_name = generic_file.file_name if generic_file else "unknown"
file_name = file_metadata.file_name if file_metadata else "unknown"
return {
"type": "text",
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
}
# 使用策略格式化文档
return await strategy.format_document(file_name, text)
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理音频文件
Args:
file: 音频文件输入
strategy: 格式化策略
Returns:
Dict: 音频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的音频内容
"""
# TODO: 实现音频转文字功能
return {
"type": "text",
"text": "[音频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
# 如果启用音频转文本且有 API Key
transcription = None
if self.enable_audio_transcription and self.api_key:
logger.info(f"开始音频转文本: {url}")
if self.provider == "dashscope":
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
elif self.provider == "openai":
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
else:
logger.warning(f"Provider {self.provider} 不支持音频转文本")
return await strategy.format_audio(file.file_type, url, transcription)
except Exception as e:
logger.error(f"处理音频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[音频处理失败: {str(e)}]"
}
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
"""
处理视频文件
Args:
file: 视频文件输入
strategy: 格式化策略
Returns:
Dict: 视频内容(暂时返回占位符)
Dict: 根据 provider 返回不同格式的视频内容
"""
# TODO: 实现视频处理功能
return {
"type": "text",
"text": "[视频文件,暂不支持处理]"
}
try:
url = await self.get_file_url(file)
return await strategy.format_video(url)
except Exception as e:
logger.error(f"处理视频失败: {e}", exc_info=True)
return {
"type": "text",
"text": f"[视频处理失败: {str(e)}]"
}
async def get_file_url(self, file: FileInput) -> str:
"""
@@ -336,26 +499,22 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return file.url
else:
# 本地文件,通过 file_storage 系统获取永久访问 URL
from app.models.file_metadata_model import FileMetadata
from app.core.config import settings
file_id = file.upload_file_id
print("="*50)
print("file_id",file_id)
# 查询 FileMetadata
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# 返回永久URL
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
@@ -370,58 +529,79 @@ class MultimodalService:
Returns:
str: 提取的文本内容
"""
generic_file = self.db.query(GenericFile).filter(
GenericFile.id == file_id,
GenericFile.status == "active"
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file_id,
FileMetadata.status == "completed"
).first()
if not generic_file:
if not file_metadata:
raise BusinessException(
f"文件不存在或已删除: {file_id}",
BizCode.NOT_FOUND
)
# TODO: 根据文件类型提取文本
# - PDF: 使用 PyPDF2 或 pdfplumber
# - Word: 使用 python-docx
# - TXT/MD: 直接读取
file_ext = generic_file.file_ext.lower()
file_ext = file_metadata.file_ext.lower()
server_url = settings.FILE_LOCAL_SERVER_URL
file_url = f"{server_url}/storage/permanent/{file_id}"
if file_ext in ['.txt', '.md', '.markdown']:
return await self._read_text_file(generic_file.storage_path)
return await self._read_text_file(file_url)
elif file_ext == '.pdf':
return await self._extract_pdf_text(generic_file.storage_path)
return await self._extract_pdf_text(file_url)
elif file_ext in ['.doc', '.docx']:
return await self._extract_word_text(generic_file.storage_path)
return await self._extract_word_text(file_url)
else:
return f"[不支持的文档格式: {file_ext}]"
async def _read_text_file(self, storage_path: str) -> str:
@staticmethod
async def _read_text_file(file_url: str) -> str:
"""读取纯文本文件"""
try:
with open(storage_path, 'r', encoding='utf-8') as f:
return f.read()
# 下载文件
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
return response.text
except Exception as e:
logger.error(f"读取文本文件失败: {e}")
return f"[文件读取失败: {str(e)}]"
async def _extract_pdf_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_pdf_text(file_url: str) -> str:
"""提取 PDF 文本"""
try:
# TODO: 实现 PDF 文本提取
# import PyPDF2 或 pdfplumber
return "[PDF 文本提取功能待实现]"
# 下载 PDF 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
pdf_data = response.content
# 使用 BytesIO 读取 PDF
text_parts = []
pdf_file = io.BytesIO(pdf_data)
pdf_reader = PyPDF2.PdfReader(pdf_file)
for page in pdf_reader.pages:
text_parts.append(page.extract_text())
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 PDF 文本失败: {e}")
return f"[PDF 提取失败: {str(e)}]"
async def _extract_word_text(self, storage_path: str) -> str:
@staticmethod
async def _extract_word_text(file_url: str) -> str:
"""提取 Word 文档文本"""
try:
# TODO: 实现 Word 文本提取
# import docx
return "[Word 文本提取功能待实现]"
# 下载 Word 文
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(file_url)
response.raise_for_status()
word_data = response.content
# 使用 BytesIO 读取 Word 文档
word_file = io.BytesIO(word_data)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
except Exception as e:
logger.error(f"提取 Word 文本失败: {e}")
return f"[Word 提取失败: {str(e)}]"

View File

@@ -101,34 +101,141 @@ async def run_pilot_extraction(
)
if progress_callback:
await progress_callback("text_preprocessing", "开始预处理文本...")
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
# ========== 步骤 2.1: 语义剪枝 ==========
pruned_dialogs = [dialog]
deleted_messages = [] # 记录被删除的消息
pruning_stats = None # 保存剪枝统计信息,用于最终汇总
if memory_config.pruning_enabled:
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
SemanticPruner,
)
from app.core.memory.models.config_models import PruningConfig
# 构建剪枝配置
pruning_config_dict = {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
"llm_model_id": str(memory_config.llm_model_id),
}
config = PruningConfig(**pruning_config_dict)
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
# 记录剪枝前的消息(用于对比)
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
original_msg_count = len(original_messages)
# 执行剪枝
pruner = SemanticPruner(config=config, llm_client=llm_client)
pruned_dialogs = await pruner.prune_dataset([dialog])
# 计算剪枝结果并找出被删除的消息
if pruned_dialogs and pruned_dialogs[0].context:
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
remaining_msg_count = len(remaining_messages)
deleted_msg_count = original_msg_count - remaining_msg_count
# 找出被删除的消息(基于索引精确匹配)
# 为剩余消息创建带索引的列表,用于精确追踪
remaining_with_index = []
remaining_idx = 0
for orig_idx, orig_msg in enumerate(original_messages):
if remaining_idx < len(remaining_messages) and \
orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \
orig_msg["content"] == remaining_messages[remaining_idx]["content"]:
remaining_with_index.append(orig_idx)
remaining_idx += 1
# 找出未在保留列表中的消息索引
deleted_messages = [
{"index": idx, "role": msg["role"], "content": msg["content"]}
for idx, msg in enumerate(original_messages)
if idx not in remaining_with_index
]
# 保存剪枝统计信息用于最终汇总只保留deleted_count
pruning_stats = {
"enabled": True,
"scene": config.pruning_scene,
"threshold": config.pruning_threshold,
"deleted_count": deleted_msg_count,
}
# 输出剪枝结果(显示删除的消息详情)
pruning_result = {
"type": "pruning",
"deleted_messages": deleted_messages,
}
logger.info(
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
)
if progress_callback:
await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
else:
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
pruned_dialogs = [dialog]
except Exception as e:
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
pruned_dialogs = [dialog]
if progress_callback:
error_result = {
"type": "pruning",
"error": str(e),
"fallback": "使用原始对话"
}
await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
else:
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
pruning_stats = {
"enabled": False,
}
# ========== 步骤 2.2: 语义分块 ==========
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
data=[dialog],
data=pruned_dialogs,
chunker_strategy=memory_config.chunker_strategy,
llm_client=llm_client,
)
logger.info(f"Processed dialogue text: {len(messages)} messages")
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
# 进度回调:输出每个分块的结果
if progress_callback:
for dlg in chunked_dialogs:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
if hasattr(dlg, 'chunks') and dlg.chunks:
for i, chunk in enumerate(dlg.chunks):
chunk_result = {
"type": "chunking",
"chunk_index": i + 1,
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
"full_length": len(chunk.content),
"dialog_id": dlg.id,
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
# 构建预处理完成总结(包含剪枝统计)
preprocessing_summary = {
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
"total_dialogs": len(chunked_dialogs),
"chunker_strategy": memory_config.chunker_strategy,
}
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
# 添加剪枝统计信息
if pruning_stats:
preprocessing_summary["pruning"] = pruning_stats
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
log_time("Data Loading & Chunking", time.time() - step_start, log_file)

View File

@@ -184,7 +184,8 @@ class PromptOptimizerService:
model_name=api_config.model_name,
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base
base_url=api_config.api_base,
is_omni=api_config.is_omni
), type=ModelType(model_config.type))
try:
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')

View File

@@ -21,63 +21,64 @@ from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger()
class SharedChatService:
"""基于分享链接的聊天服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_service = ConversationService(db)
self.share_service = ReleaseShareService(db)
def _get_release_by_share_token(
self,
share_token: str,
password: Optional[str] = None
def get_release_by_share_token(
self,
share_token: str,
password: Optional[str] = None
) -> tuple[ReleaseShare, AppRelease]:
"""通过 share_token 获取发布版本"""
# 获取分享配置
share = self.share_service.repo.get_by_share_token(share_token)
if not share:
raise ResourceNotFoundException("分享链接", share_token)
# 验证分享是否启用
if not share.is_enabled:
raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED)
# 验证密码
if share.require_password:
if not password:
raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED)
if not self.share_service.verify_password(share_token, password):
raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD)
# 获取发布版本
release = self.db.get(AppRelease, share.release_id)
if not release:
raise ResourceNotFoundException("发布版本", str(share.release_id))
# 更新访问统计
try:
self.share_service.repo.increment_view_count(share.id)
except Exception as e:
logger.warning(f"更新访问统计失败: {str(e)}")
return share, release
def create_or_get_conversation(
self,
share_token: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
password: Optional[str] = None
self,
share_token: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
password: Optional[str] = None
) -> Conversation:
"""创建或获取会话"""
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 如果提供了 conversation_id尝试获取现有会话
if conversation_id:
try:
@@ -85,18 +86,18 @@ class SharedChatService:
conversation_id=conversation_id,
workspace_id=release.app.workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
except ResourceNotFoundException:
logger.warning(
"会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)}
)
# 创建新会话(使用发布版本的配置)
conversation = self.conversation_service.create_conversation(
app_id=release.app_id,
@@ -105,7 +106,7 @@ class SharedChatService:
is_draft=False, # 分享链接使用发布版本
config_snapshot=release.config
)
logger.info(
"为分享链接创建新会话",
extra={
@@ -114,25 +115,25 @@ class SharedChatService:
"release_id": str(release.id)
}
)
return conversation
async def chat(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
actual_config_id = None
config_id=actual_config_id
config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
@@ -140,32 +141,30 @@ class SharedChatService:
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 获取 Agent 配置
config = release.config or {}
# 获取模型配置ID
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key
# stmt = (
# select(ModelApiKey).join(
@@ -184,7 +183,7 @@ class SharedChatService:
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
@@ -192,7 +191,7 @@ class SharedChatService:
user_id=user_id,
password=password
)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables:
@@ -202,31 +201,31 @@ class SharedChatService:
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag=False
memory_flag = False
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag=True
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools=config.get("tools")
web_tools = config.get("tools")
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled",False)
web_search_enable = web_search_choice.get("enabled", False)
if web_search:
if web_search_enable:
search_tool = create_web_search_tool({})
@@ -238,26 +237,27 @@ class SharedChatService:
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
)
# 加载历史消息
history = []
memory_config={"enabled":True,'max_history':10}
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation.id,
@@ -267,7 +267,7 @@ class SharedChatService:
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 调用 Agent
result = await agent.chat(
message=message,
@@ -279,7 +279,7 @@ class SharedChatService:
config_id=config_id,
memory_flag=memory_flag
)
# 保存消息
self.conversation_service.save_conversation_messages(
conversation_id=conversation.id,
@@ -298,7 +298,7 @@ class SharedChatService:
# role="user",
# content=message
# )
# self.conversation_service.add_message(
# conversation_id=conversation.id,
# role="assistant",
@@ -308,12 +308,11 @@ class SharedChatService:
# "usage": result.get("usage", {})
# }
# )
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
return {
"conversation_id": conversation.id,
"message": result["content"],
@@ -324,19 +323,19 @@ class SharedChatService:
}),
"elapsed_time": elapsed_time
}
async def chat_stream(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
storage_type:Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
from app.core.agent.langchain_agent import LangChainAgent
@@ -345,36 +344,35 @@ class SharedChatService:
from sqlalchemy import select
from app.models import ModelApiKey
import json
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
start_time = time.time()
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
# 兼容新旧字段名:使用 memory_config_id
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
try:
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 获取 Agent 配置
config = release.config or {}
agent_config_data = config.get("agent_config", {})
# 获取模型配置ID
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
# 获取模型配置
from app.models import ModelConfig
model_config = self.db.get(ModelConfig, model_config_id)
if not model_config:
raise ResourceNotFoundException("模型配置", str(model_config_id))
# 获取 API Key
# stmt = (
# select(ModelApiKey).join(
@@ -393,7 +391,7 @@ class SharedChatService:
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
if not api_key_obj:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
@@ -401,7 +399,7 @@ class SharedChatService:
user_id=user_id,
password=password
)
# 处理系统提示词(支持变量替换)
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
if variables:
@@ -411,21 +409,21 @@ class SharedChatService:
variables
)
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# 准备工具列表
tools = []
# 添加知识库检索工具
knowledge_retrieval = config.get("knowledge_retrieval")
if knowledge_retrieval:
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
if kb_ids:
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag=False
memory_flag = False
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
@@ -450,20 +448,21 @@ class SharedChatService:
# 获取模型参数
model_parameters = config.get("model_parameters", {})
# 创建 LangChain Agent
agent = LangChainAgent(
model_name=api_key_obj.model_name,
api_key=api_key_obj.api_key,
provider=api_key_obj.provider,
api_base=api_key_obj.api_base,
is_omni=api_key_obj.is_omni,
temperature=model_parameters.get("temperature", 0.7),
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
streaming=True
)
# 加载历史消息
history = []
memory_config = {"enabled": True, 'max_history': 10}
@@ -476,22 +475,22 @@ class SharedChatService:
{"role": msg.role, "content": msg.content}
for msg in messages
]
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 流式调用 Agent
full_content = ""
total_tokens = 0
async for chunk in agent.chat_stream(
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
message=message,
history=history,
context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
):
if isinstance(chunk, int):
total_tokens = chunk
@@ -499,16 +498,16 @@ class SharedChatService:
full_content += chunk
# 发送消息块事件
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
@@ -524,7 +523,7 @@ class SharedChatService:
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
"流式聊天完成",
extra={
@@ -533,7 +532,7 @@ class SharedChatService:
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("流式聊天被中断")
@@ -542,39 +541,39 @@ class SharedChatService:
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
# 发送错误事件
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
def get_conversation_messages(
self,
share_token: str,
conversation_id: uuid.UUID,
password: Optional[str] = None
self,
share_token: str,
conversation_id: uuid.UUID,
password: Optional[str] = None
) -> Conversation:
"""获取会话消息"""
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 获取会话
conversation = self.conversation_service.get_conversation(
conversation_id=conversation_id,
workspace_id=release.app.workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != release.app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
def list_conversations(
self,
share_token: str,
user_id: Optional[str] = None,
password: Optional[str] = None,
page: int = 1,
pagesize: int = 20
self,
share_token: str,
user_id: Optional[str] = None,
password: Optional[str] = None,
page: int = 1,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""列出会话"""
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
conversations, total = self.conversation_service.list_conversations(
app_id=release.app_id,
workspace_id=release.app.workspace_id,
@@ -583,19 +582,19 @@ class SharedChatService:
page=page,
pagesize=pagesize
)
return conversations, total
async def multi_agent_chat(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
@@ -603,18 +602,16 @@ class SharedChatService:
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
@@ -622,19 +619,19 @@ class SharedChatService:
user_id=user_id,
password=password
)
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active.is_(True)
).first()
if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=message,
conversation_id=conversation.id,
@@ -644,23 +641,23 @@ class SharedChatService:
web_search=web_search,
memory=memory
)
# 使用多 Agent 服务执行
multi_agent_service = MultiAgentService(self.db)
result = await multi_agent_service.run(
app_id=release.app_id,
request=multi_agent_request
)
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
@@ -672,8 +669,6 @@ class SharedChatService:
}
)
return {
"conversation_id": conversation.id,
"message": result.get("message", ""),
@@ -684,34 +679,33 @@ class SharedChatService:
},
"elapsed_time": elapsed_time
}
async def multi_agent_chat_stream(
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
self,
share_token: str,
message: str,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id:Optional[str] = None
user_rag_memory_id: Optional[str] = None
) -> AsyncGenerator[str, None]:
"""多 Agent 聊天(流式)"""
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
actual_config_id = None
config_id = actual_config_id
if variables is None:
variables = {}
try:
# 获取发布版本和配置
share, release = self._get_release_by_share_token(share_token, password)
share, release = self.get_release_by_share_token(share_token, password)
# 获取或创建会话
conversation = self.create_or_get_conversation(
share_token=share_token,
@@ -719,28 +713,28 @@ class SharedChatService:
user_id=user_id,
password=password
)
# 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active.is_(True)
).first()
if not multi_agent_config:
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 获取 storage_type 和 user_rag_memory_id
workspace_id = release.app.workspace_id
storage_type = 'neo4j' # 默认值
user_rag_memory_id = ''
try:
# 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享)
from app.models import Workspace
workspace = self.db.get(Workspace, workspace_id)
if workspace and workspace.storage_type:
storage_type = workspace.storage_type
# 获取 USER_RAG_MERORY 知识库 ID
knowledge = knowledge_repository.get_knowledge_by_name(
db=self.db,
@@ -751,13 +745,13 @@ class SharedChatService:
user_rag_memory_id = str(knowledge.id)
except Exception as e:
logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}")
# 发送开始事件
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
# 构建多 Agent 运行请求
from app.schemas.multi_agent_schema import MultiAgentRunRequest
multi_agent_request = MultiAgentRunRequest(
message=message,
conversation_id=conversation.id,
@@ -767,20 +761,20 @@ class SharedChatService:
web_search=web_search,
memory=memory
)
# 使用多 Agent 服务流式执行
multi_agent_service = MultiAgentService(self.db)
full_content = ""
async for event in multi_agent_service.run_stream(
app_id=release.app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
app_id=release.app_id,
request=multi_agent_request,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
# 直接转发事件
yield event
# 尝试提取内容(用于保存)
if "data:" in event:
try:
@@ -790,16 +784,16 @@ class SharedChatService:
full_content += data["content"]
except:
pass
elapsed_time = time.time() - start_time
# 保存消息
self.conversation_service.add_message(
conversation_id=conversation.id,
role="user",
content=message
)
self.conversation_service.add_message(
conversation_id=conversation.id,
role="assistant",
@@ -808,7 +802,7 @@ class SharedChatService:
"elapsed_time": elapsed_time
}
)
logger.info(
"多 Agent 流式聊天完成",
extra={
@@ -818,7 +812,6 @@ class SharedChatService:
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出
logger.debug("多 Agent 流式聊天被中断")

View File

@@ -121,7 +121,7 @@ class SkillService:
if skill and skill.is_active:
# 加载技能关联的工具
for tool_config in skill.tools:
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
if tool:
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)

View File

@@ -209,7 +209,7 @@ class ToolService:
try:
# 获取工具实例
tool = self._get_tool_instance(tool_id, tenant_id)
tool = self.get_tool_instance(tool_id, tenant_id)
if not tool:
return ToolResult.error_result(
error=f"工具不存在: {tool_id}",
@@ -335,7 +335,7 @@ class ToolService:
return []
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return []
@@ -792,7 +792,7 @@ class ToolService:
"""获取工具配置"""
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
"""获取工具实例"""
if tool_id in self._tool_cache:
return self._tool_cache[tool_id]
@@ -1416,7 +1416,7 @@ class ToolService:
"""测试内置工具连接"""
try:
# 获取工具实例
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
if not tool_instance:
return {"success": False, "message": "无法创建工具实例"}

View File

@@ -10,6 +10,9 @@ from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from app.core.logging_config import get_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
@@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.services.memory_short_service import ShortService
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_logger(__name__)
@@ -1035,9 +1036,10 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan
"growth_trajectory": str # 成长轨迹
}
"""
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
from app.core.language_utils import validate_language
import re
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
# 验证语言参数
language = validate_language(language)
@@ -1161,13 +1163,35 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
"one_sentence": str
}
"""
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.core.language_utils import validate_language
import re
from app.core.language_utils import validate_language
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
# 验证语言参数
language = validate_language(language)
# 获取用户的 other_name 字段
user_display_name = "该用户" if language == "zh" else "the user"
if end_user_id:
try:
# 获取数据库会话并查询用户信息
db = next(get_db())
try:
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
if end_user and end_user.other_name:
user_display_name = end_user.other_name
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
else:
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
finally:
db.close()
except Exception as e:
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
# 创建 UserSummaryHelper 实例
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
@@ -1184,7 +1208,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
user_id=user_summary_tool.user_id,
entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)",
statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)",
language=language
language=language,
user_display_name=user_display_name
)
messages = [
@@ -1435,7 +1460,7 @@ async def analytics_memory_types(
short_term_count = 0
if end_user_id:
try:
short_term_service = ShortService(end_user_id)
short_term_service = ShortService(end_user_id, db)
short_term_data = short_term_service.get_short_databasets()
# 统计 short_term 数组的长度
if short_term_data:
@@ -1449,8 +1474,10 @@ async def analytics_memory_types(
forgetting_threshold = 0.3 # 默认值
if end_user_id:
try:
from app.core.memory.storage_services.forgetting_engine.config_utils import (
load_actr_config_from_db,
)
from app.services.memory_agent_service import get_end_user_connected_config
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
# 获取用户关联的 config_id
connected_config = get_end_user_connected_config(end_user_id, db)

View File

@@ -13,7 +13,10 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.core.workflow.variable.base_variable import FileObject
from app.db import get_db
from app.models import App
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
@@ -22,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest
from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
@@ -444,6 +447,94 @@ class WorkflowService:
"success_rate": completed / total if total > 0 else 0
}
async def _handle_file_input(self, files: list[FileInput]):
if not files:
return []
files_struct = []
for file in files:
files_struct.append(
FileObject(
type=file.type,
url=await self.multimodal_service.get_file_url(file),
transfer_method=file.transfer_method,
file_id=str(file.upload_file_id),
origin_file_type=file.file_type,
is_file=True
).model_dump()
)
return files_struct
@staticmethod
def _map_public_event(event: dict) -> dict | None:
"""
Map internal workflow events to public-facing event formats.
Purpose:
- Hide internal execution details
- Expose a stable and simplified public event schema
- Filter out non-public events
- Maintain backward compatibility when possible
Args:
event (dict): Internal event object, e.g.:
{
"event": "workflow_start",
"data": {...}
}
Returns:
dict | None:
- Returns the mapped public event
- Returns None if the event should not be exposed
"""
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
"""
Unified event emission entry.
Args:
public (bool):
- True -> Emit mapped public event
- False -> Emit raw internal event
internal_event (dict):
The original internal event object
Returns:
dict | None:
- The mapped event
- Or None if the event is filtered out
"""
if public:
mapped = self._map_public_event(internal_event)
else:
mapped = internal_event
return mapped
# ==================== 工作流执行 ====================
async def run(
@@ -478,10 +569,11 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
input_data = {
"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
# 转换 conversation_id 为 UUID
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
@@ -505,22 +597,8 @@ class WorkflowService:
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
files = await self._handle_file_input(payload.files)
input_data["files"] = files
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
@@ -580,6 +658,7 @@ class WorkflowService:
# "variables": result.get("variables"),
# "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串)
"message": result.get("output"), # 最终输出(字符串)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"),
@@ -599,41 +678,6 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
@staticmethod
def _map_public_event(event: dict) -> dict | None:
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", ""))
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
"""
decide
"""
if public:
mapped = self._map_public_event(internal_event)
else:
mapped = internal_event
return mapped
async def run_stream(
self,
app_id: uuid.UUID,
@@ -668,10 +712,11 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
input_data = {
"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
# 转换 conversation_id 为 UUID
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
@@ -696,16 +741,7 @@ class WorkflowService:
}
try:
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
files = await self._handle_file_input(payload.files)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -720,7 +756,6 @@ class WorkflowService:
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
async for event in execute_workflow_stream(
workflow_config=workflow_config_dict,
@@ -778,36 +813,13 @@ class WorkflowService:
}
}
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
"""清理事件数据,移除不可序列化的对象
Args:
event: 原始事件数据
Returns:
可序列化的事件数据
"""
from langchain_core.messages import BaseMessage
def clean_value(value):
"""递归清理值"""
if isinstance(value, BaseMessage):
# 将 Message 对象转换为字典
return {
"type": value.__class__.__name__,
"content": value.content,
}
elif isinstance(value, dict):
return {k: clean_value(v) for k, v in value.items()}
elif isinstance(value, list):
return [clean_value(item) for item in value]
elif isinstance(value, (str, int, float, bool, type(None))):
return value
else:
# 其他不可序列化的对象转换为字符串
return str(value)
return clean_value(event)
@staticmethod
def get_start_node_variables(config: dict) -> list:
nodes = config.get("nodes", [])
for node in nodes:
if node.get("type") == NodeType.START:
return node.get("config", {}).get("variables", [])
raise BusinessException("workflow config error - start node not found")
# ==================== 依赖注入函数 ====================

View File

@@ -1,16 +1,16 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import os
import re
import shutil
import time
import uuid
from uuid import UUID
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from math import ceil
from pathlib import Path
import shutil
from typing import Any, Dict, List, Optional
from uuid import UUID
import redis
import requests
@@ -38,7 +38,7 @@ from app.db import get_db, get_db_context
from app.models.document_model import Document
from app.models.file_model import File
from app.models.knowledge_model import Knowledge
from app.schemas import file_schema, document_schema
from app.schemas import document_schema, file_schema
from app.services.memory_agent_service import MemoryAgentService
from app.utils.config_utils import resolve_config_id
@@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID):
Document parsing, vectorization, and storage
"""
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
import trio
import importlib
import trio
importlib.reload(trio)
db = next(get_db()) # Manually call the generator
db_document = None
@@ -256,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
return result
try:
def sync_task():
trio.run(
lambda: _run(
row=task,
@@ -271,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community,
)
)
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
future.result() # Blocks until the task completes
except Exception as e:
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
@@ -297,8 +302,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
build knowledge graph
"""
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
import trio
import importlib
import trio
importlib.reload(trio)
db = next(get_db()) # Manually call the generator
db_documents = None
@@ -932,24 +938,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
with get_db_context() as db:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
db = next(get_db())
try:
with get_db_context() as db:
service = MemoryAgentService()
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -1049,19 +1049,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
with get_db_context() as db:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
db = next(get_db())
try:
with get_db_context() as db:
logger.info(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService()
@@ -1069,11 +1065,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
user_rag_memory_id, language)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
except Exception as e:
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
raise
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -1304,6 +1295,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
"workspace_id": workspace_id,
"elapsed_time": elapsed_time,
}
@celery_app.task(
name="app.tasks.write_all_workspaces_memory_task",
bind=True,
ignore_result=False,
max_retries=3,
acks_late=True,
time_limit=3600,
soft_time_limit=3300,
)
def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
"""定时任务:遍历所有工作空间,统计并写入记忆增量
此任务会:
1. 查询所有活跃的工作空间
2. 对每个工作空间统计记忆总量
3. 将统计结果写入 memory_increments 表
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
api_logger = get_api_logger()
with get_db_context() as db:
try:
# 获取所有活跃的工作空间
workspaces = db.query(Workspace).filter(
Workspace.is_active.is_(True)
).all()
if not workspaces:
api_logger.warning("没有找到活跃的工作空间")
return {
"status": "SUCCESS",
"message": "没有找到活跃的工作空间",
"workspace_count": 0,
"workspace_results": []
}
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
all_workspace_results = []
# 遍历每个工作空间
for workspace in workspaces:
workspace_id = workspace.id
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
try:
# 1. 查询当前workspace下的所有app仅未删除的
apps = db.query(App).filter(
App.workspace_id == workspace_id,
App.is_active.is_(True)
).all()
if not apps:
# 如果没有app总量为0
memory_increment = write_memory_increment(
db=db,
workspace_id=workspace_id,
total_num=0
)
all_workspace_results.append({
"workspace_id": str(workspace_id),
"workspace_name": workspace.name,
"status": "SUCCESS",
"total_num": 0,
"end_user_count": 0,
"memory_increment_id": str(memory_increment.id),
"created_at": memory_increment.created_at.isoformat(),
})
api_logger.info(f"工作空间 {workspace.name} 没有应用记录总量为0")
continue
# 2. 查询所有app下的end_user_id去重
app_ids = [app.id for app in apps]
end_users = db.query(EndUser.id).filter(
EndUser.app_id.in_(app_ids)
).distinct().all()
# 3. 遍历所有end_user查询每个宿主的记忆总量并累加
total_num = 0
end_user_details = []
for (end_user_id,) in end_users:
try:
# 调用 search_all 接口查询该宿主的总量
result = await search_all(str(end_user_id))
user_total = result.get("total", 0)
total_num += user_total
end_user_details.append({
"end_user_id": str(end_user_id),
"total": user_total
})
except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
end_user_details.append({
"end_user_id": str(end_user_id),
"total": 0,
"error": str(e)
})
# 4. 写入数据库
memory_increment = write_memory_increment(
db=db,
workspace_id=workspace_id,
total_num=total_num
)
all_workspace_results.append({
"workspace_id": str(workspace_id),
"workspace_name": workspace.name,
"status": "SUCCESS",
"total_num": total_num,
"end_user_count": len(end_users),
"memory_increment_id": str(memory_increment.id),
"created_at": memory_increment.created_at.isoformat(),
})
api_logger.info(
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
)
except Exception as e:
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
all_workspace_results.append({
"workspace_id": str(workspace_id),
"workspace_name": workspace.name,
"status": "FAILURE",
"error": str(e),
"total_num": 0,
"end_user_count": 0,
})
total_memory = sum(r.get("total_num", 0) for r in all_workspace_results)
success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS")
return {
"status": "SUCCESS",
"message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}",
"workspace_count": len(workspaces),
"success_count": success_count,
"total_memory": total_memory,
"workspace_results": all_workspace_results
}
except Exception as e:
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
"workspace_count": 0,
"workspace_results": []
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
return result
except Exception as e:
elapsed_time = time.time() - start_time
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
@celery_app.task(
@@ -1924,4 +2112,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# }
# =============================================================================
# 隐性记忆和情绪数据更新定时任务
# =============================================================================
@celery_app.task(
name="app.tasks.update_implicit_emotions_storage",
bind=True,
ignore_result=True,
max_retries=0,
acks_late=False,
time_limit=7200, # 2小时硬超时
soft_time_limit=6900, # 1小时55分钟软超时
)
def update_implicit_emotions_storage(self) -> Dict[str, Any]:
"""定时任务:更新所有用户的隐性记忆画像和情绪建议数据
遍历数据库中所有已存在数据的用户,为每个用户重新生成隐性记忆画像和情绪建议。
实现错误隔离,单个用户失败不影响其他用户的处理。
Returns:
包含任务执行结果的字典,包括:
- status: 任务状态 (SUCCESS/FAILURE)
- message: 执行消息
- total_users: 总用户数
- successful_implicit: 成功更新隐性记忆的用户数
- successful_emotion: 成功更新情绪建议的用户数
- failed: 失败的用户数
- user_results: 每个用户的详细结果
- elapsed_time: 执行耗时(秒)
- task_id: 任务ID
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
from sqlalchemy import select, func
from app.services.implicit_memory_service import ImplicitMemoryService
from app.services.emotion_analytics_service import EmotionAnalyticsService
logger = get_logger(__name__)
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
total_users = 0
successful_implicit = 0
successful_emotion = 0
failed = 0
user_results = []
with get_db_context() as db:
try:
# 获取所有已存储数据的用户ID分批次处理
repo = ImplicitEmotionsStorageRepository(db)
# 先统计总数用于日志
from sqlalchemy import func
total_users = db.execute(
select(func.count()).select_from(ImplicitEmotionsStorage)
).scalar() or 0
logger.info(f"找到 {total_users} 个需要更新的用户")
# 遍历每个用户并更新数据分批次避免一次性加载所有ID
for end_user_id in repo.get_all_user_ids(batch_size=100):
logger.info(f"开始处理用户: {end_user_id}")
user_start_time = time.time()
implicit_success = False
emotion_success = False
errors = []
try:
# 更新隐性记忆画像
try:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache(
end_user_id=end_user_id,
profile_data=profile_data,
db=db
)
implicit_success = True
logger.info(f"成功更新用户 {end_user_id} 的隐性记忆画像")
except Exception as e:
error_msg = f"隐性记忆更新失败: {str(e)}"
errors.append(error_msg)
logger.error(f"用户 {end_user_id} {error_msg}")
# 更新情绪建议
try:
emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id,
db=db,
language="zh"
)
await emotion_service.save_suggestions_cache(
end_user_id=end_user_id,
suggestions_data=suggestions_data,
db=db
)
emotion_success = True
logger.info(f"成功更新用户 {end_user_id} 的情绪建议")
except Exception as e:
error_msg = f"情绪建议更新失败: {str(e)}"
errors.append(error_msg)
logger.error(f"用户 {end_user_id} {error_msg}")
# 统计结果
if implicit_success:
successful_implicit += 1
if emotion_success:
successful_emotion += 1
if not implicit_success and not emotion_success:
failed += 1
user_elapsed = time.time() - user_start_time
# 记录用户处理结果
user_result = {
"end_user_id": end_user_id,
"implicit_success": implicit_success,
"emotion_success": emotion_success,
"errors": errors,
"elapsed_time": user_elapsed
}
user_results.append(user_result)
logger.info(
f"用户 {end_user_id} 处理完成: "
f"隐性记忆={'成功' if implicit_success else '失败'}, "
f"情绪建议={'成功' if emotion_success else '失败'}, "
f"耗时={user_elapsed:.2f}"
)
except Exception as e:
# 单个用户失败不影响其他用户(错误隔离)
failed += 1
user_elapsed = time.time() - user_start_time
error_info = {
"end_user_id": end_user_id,
"implicit_success": False,
"emotion_success": False,
"errors": [str(e)],
"elapsed_time": user_elapsed
}
user_results.append(error_info)
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
new_users_initialized = 0
new_users_failed = 0
logger.info("开始处理当天新增的增量用户初始化")
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
logger.info(f"开始初始化新用户: {end_user_id}")
user_start_time = time.time()
implicit_success = False
emotion_success = False
errors = []
try:
try:
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
await implicit_service.save_profile_cache(
end_user_id=end_user_id,
profile_data=profile_data,
db=db
)
implicit_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
except Exception as e:
error_msg = f"隐性记忆初始化失败: {str(e)}"
errors.append(error_msg)
logger.error(f"新用户 {end_user_id} {error_msg}")
try:
emotion_service = EmotionAnalyticsService()
suggestions_data = await emotion_service.generate_emotion_suggestions(
end_user_id=end_user_id,
db=db,
language="zh"
)
await emotion_service.save_suggestions_cache(
end_user_id=end_user_id,
suggestions_data=suggestions_data,
db=db
)
emotion_success = True
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
except Exception as e:
error_msg = f"情绪建议初始化失败: {str(e)}"
errors.append(error_msg)
logger.error(f"新用户 {end_user_id} {error_msg}")
if implicit_success or emotion_success:
new_users_initialized += 1
else:
new_users_failed += 1
user_elapsed = time.time() - user_start_time
user_results.append({
"end_user_id": end_user_id,
"type": "init",
"implicit_success": implicit_success,
"emotion_success": emotion_success,
"errors": errors,
"elapsed_time": user_elapsed
})
except Exception as e:
new_users_failed += 1
user_elapsed = time.time() - user_start_time
user_results.append({
"end_user_id": end_user_id,
"type": "init",
"implicit_success": False,
"emotion_success": False,
"errors": [str(e)],
"elapsed_time": user_elapsed
})
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
logger.info(
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
)
# ---- 增量用户处理结束 ----
# 记录总体统计信息
logger.info(
f"隐性记忆和情绪数据更新定时任务完成: "
f"存量用户总数={total_users}, "
f"隐性记忆成功={successful_implicit}, "
f"情绪建议成功={successful_emotion}, "
f"存量失败={failed}, "
f"增量初始化成功={new_users_initialized}, "
f"增量初始化失败={new_users_failed}"
)
return {
"status": "SUCCESS",
"message": (
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
),
"total_users": total_users,
"successful_implicit": successful_implicit,
"successful_emotion": successful_emotion,
"failed": failed,
"new_users_initialized": new_users_initialized,
"new_users_failed": new_users_failed,
"user_results": user_results[:50] # 只保留前50个用户的详细结果
}
except Exception as e:
logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}")
return {
"status": "FAILURE",
"error": str(e),
"total_users": total_users,
"successful_implicit": successful_implicit,
"successful_emotion": successful_emotion,
"failed": failed,
"new_users_initialized": 0,
"new_users_failed": 0,
"user_results": user_results[:50]
}
try:
# 使用 nest_asyncio 来避免事件循环冲突
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
# 尝试获取现有事件循环,如果不存在则创建新的
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
result["elapsed_time"] = elapsed_time
result["task_id"] = self.request.id
return result
except Exception as e:
elapsed_time = time.time() - start_time
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": elapsed_time,
"task_id": self.request.id
}

View File

@@ -139,7 +139,7 @@ SMTP_USER=
SMTP_PASSWORD=
# 本体类型融合配置 (记得写入env_example)
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中

View File

@@ -0,0 +1,43 @@
"""202603051440
Revision ID: 6a4641cf192b
Revises: b4af97639217
Create Date: 2026-03-05 14:41:03.371557
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '6a4641cf192b'
down_revision: Union[str, None] = 'b4af97639217'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('implicit_emotions_storage',
sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'),
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'),
sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'),
sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('end_user_id')
)
op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index('idx_updated_at', table_name='implicit_emotions_storage')
op.drop_table('implicit_emotions_storage')
# ### end Alembic commands ###

View File

@@ -0,0 +1,63 @@
"""202603051033
Revision ID: b4af97639217
Revises: 4bf27c66ae63
Create Date: 2026-03-05 10:36:06.282227
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b4af97639217'
down_revision: Union[str, None] = '4bf27c66ae63'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# Add columns as nullable first to avoid table locks
op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video']"))
op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型使用特殊API调用'))
# Update existing rows with default values
op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL")
op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL")
op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL")
op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL")
# Now make columns NOT NULL
op.alter_column('model_api_keys', 'capability', nullable=False)
op.alter_column('model_api_keys', 'is_omni', nullable=False)
op.alter_column('model_bases', 'capability', nullable=False)
op.alter_column('model_bases', 'is_omni', nullable=False)
op.alter_column('model_configs', 'capability', nullable=False)
op.alter_column('model_configs', 'is_omni', nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('model_configs', 'is_omni')
op.drop_column('model_configs', 'capability')
op.drop_column('model_bases', 'is_omni')
op.drop_column('model_bases', 'capability')
op.drop_column('model_api_keys', 'is_omni')
op.drop_column('model_api_keys', 'capability')
# ### end Alembic commands ###

2
web/.gitignore vendored
View File

@@ -23,4 +23,4 @@ dist-ssr
*.sln
*.sw?
vite.config.js
package-lock.json
package-lock.json

View File

@@ -1,205 +1,195 @@
# i18n 中英文对比报告
# Memory Bear 前端项目 - 中英文国际化对比报告
## 📊 统计概览
生成时间: 2024
- **中文键总数**: 1136
- **英文键总数**: 1052
- **中文缺失**: 27 个键
- **英文缺失**: 111 个键
## 📊 概览统计
### 文件信息
- **中文文件**: `src/i18n/zh.ts`
- **英文文件**: `src/i18n/en.ts`
### 模块统计
| 模块名称 | 中文键数 | 英文键数 | 状态 |
|---------|---------|---------|------|
| translation | ✅ | ✅ | 完整 |
## 🔍 详细对比分析
### 1. 主要模块对比
#### 1.1 基础信息 (title, memoryBear)
-**完全匹配**
- 中文: "记忆熊.AI"
- 英文: "Memory Bear.AI"
#### 1.2 首页模块 (index)
-**完全匹配** - 包含所有子键
#### 1.3 版本信息 (version)
-**完全匹配**
#### 1.4 快速操作 (quickActions)
-**完全匹配** - 包含所有功能入口
#### 1.5 引导模块 (guide)
-**完全匹配**
#### 1.6 首页引导 (indexTour)
-**完全匹配**
#### 1.7 菜单模块 (menu)
-**完全匹配** - 包含所有导航项
#### 1.8 仪表盘 (dashboard)
-**完全匹配** - 包含所有统计指标
#### 1.9 表格 (table)
-**完全匹配**
#### 1.10 头部 (header)
-**完全匹配**
#### 1.11 语言 (language)
-**完全匹配**
#### 1.12 用户管理 (user)
-**完全匹配** - 包含所有用户相关功能
#### 1.13 时区 (timezones)
-**完全匹配** - 包含全球主要时区
#### 1.14 通用 (common)
-**完全匹配** - 包含所有通用操作和提示
#### 1.15 模型管理 (model)
-**完全匹配**
#### 1.16 新模型管理 (modelNew)
-**完全匹配**
#### 1.17 知识库 (knowledgeBase)
-**完全匹配** - 包含所有知识库功能
- 包含知识图谱相关配置
#### 1.18 API (api)
-**完全匹配**
#### 1.19 记忆管理 (memory)
-**完全匹配**
#### 1.20 成员管理 (member)
-**完全匹配**
#### 1.21 记忆摘要 (memorySummary)
-**完全匹配**
#### 1.22 遗忘引擎 (forgettingEngine)
-**完全匹配**
#### 1.23 应用管理 (application)
-**完全匹配** - 包含所有应用配置功能
- 包含工作流、Agent配置等
#### 1.24 用户记忆 (userMemory)
-**完全匹配** - 包含所有记忆类型
#### 1.25 空间管理 (space)
-**完全匹配**
#### 1.26 记忆萃取引擎 (memoryExtractionEngine)
-**完全匹配** - 包含所有配置参数
#### 1.27 记忆对话 (memoryConversation)
-**完全匹配**
#### 1.28 登录 (login)
-**完全匹配**
#### 1.29 空状态 (empty)
-**完全匹配**
#### 1.30 API密钥 (apiKey)
-**完全匹配**
#### 1.31 工具管理 (tool)
-**完全匹配** - 包含MCP服务、内置工具、自定义工具
#### 1.32 工作流 (workflow)
-**完全匹配** - 包含所有节点配置
#### 1.33 情感引擎 (emotionEngine)
-**完全匹配**
#### 1.34 情感详情 (statementDetail)
-**完全匹配**
#### 1.35 反思引擎 (reflectionEngine)
-**完全匹配**
#### 1.36 定价 (pricing)
-**完全匹配** - 包含所有套餐信息
#### 1.37 遗忘详情 (forgetDetail)
-**完全匹配**
#### 1.38 情景记忆详情 (episodicDetail)
-**完全匹配**
#### 1.39 内隐记忆详情 (implicitDetail)
-**完全匹配**
#### 1.40 短期记忆详情 (shortTermDetail)
-**完全匹配**
#### 1.41 感知记忆详情 (perceptualDetail)
-**完全匹配**
#### 1.42 外显记忆详情 (explicitDetail)
-**完全匹配**
#### 1.43 工作记忆详情 (workingDetail)
-**完全匹配**
#### 1.44 本体工程 (ontology)
-**完全匹配**
#### 1.45 提示词工程 (prompt)
-**完全匹配**
#### 1.46 技能库 (skills)
-**完全匹配**
## ✅ 结论
### 整体评估
- **状态**: 🟢 完全同步
- **中英文键值对**: 完全匹配
- **结构一致性**: 100%
### 优点
1. ✅ 所有模块的中英文翻译完整
2. ✅ 键名结构完全一致
3. ✅ 嵌套层级对应准确
4. ✅ 特殊字符和变量占位符使用正确
5. ✅ 时区、语言等枚举值完整
### 建议
1. 定期检查新增功能的国际化覆盖
2. 建议添加自动化测试确保中英文键值对同步
3. 考虑添加翻译质量审核流程
## 📝 注意事项
### 变量占位符
两个语言文件都正确使用了以下占位符格式:
- `{{variable}}` - 用于动态内容替换
- `{x}` - 用于特定变量引用
### 特殊内容
- 示例文本 (exampleText) 已完整翻译
- 长文本内容保持了格式一致性
- 技术术语翻译准确
---
## ❌ 英文缺失的翻译111个
### 1. Application 模块 (3个)
- `application.cluster` - 集群
- `application.clusterDesc` - 创建Agent集群
- `application.fullAmount` - 全量
### 2. Role 角色管理模块 (15个)
- `role.roleManagement` - 角色管理
- `role.roleId` - 角色ID
- `role.roleName` - 角色名称
- `role.roleCode` - 角色编码
- `role.description` - 角色描述
- `role.status` - 状态
- `role.enabled` - 已启用
- `role.disabled` - 已停用
- `role.createTime` - 创建时间
- `role.createRole` - 新建角色
- `role.editRole` - 编辑角色
- `role.roleTemplate` - 角色模板
- `role.emptyTemplate` - 空模板
- `role.adminTemplate` - 管理员模板
- `role.userTemplate` - 用户模板
- `role.confirmDelete` - 确定要删除这个角色吗?
- `role.createSuccess` - 角色创建成功
- `role.updateSuccess` - 角色更新成功
- `role.deleteSuccess` - 角色删除成功
- `role.createFailed` - 角色创建失败
- `role.updateFailed` - 角色更新失败
- `role.deleteFailed` - 角色删除失败
### 3. Tenant 租户管理模块 (20个)
- `tenant.tenantId` - 租户ID
- `tenant.tenantName` - 租户名称
- `tenant.contactPerson` - 联系人
- `tenant.contactInfo` - 联系方式
- `tenant.status` - 状态
- `tenant.enabled` - 启用
- `tenant.disabled` - 禁用
- `tenant.expiryDate` - 到期时间
- `tenant.createTenant` - 新增租户
- `tenant.editTenant` - 编辑租户
- `tenant.searchPlaceholder` - 搜索租户ID、名称、联系人或联系方式
- `tenant.confirmDelete` - 确定要删除该租户吗?
- `tenant.confirmBatchDelete` - 确定要批量删除选中的租户吗?
- `tenant.fetchFailed` - 获取租户数据失败
- `tenant.batchEnableSuccess` - 批量启用成功
- `tenant.batchEnableFailed` - 批量启用失败
- `tenant.batchDisableSuccess` - 批量停用成功
- `tenant.batchDisableFailed` - 批量停用失败
- `tenant.exportSuccess` - 导出成功
- `tenant.batchDeleteSuccess` - 批量删除成功
- `tenant.batchDeleteFailed` - 批量删除失败
- `tenant.saveFailed` - 保存失败
- `tenant.batchImport` - 批量导入
### 4. User 用户管理模块 (13个)
- `user.tenantName` - 所属租户
- `user.password` - 密码
- `user.expiryDate` - 有效期
- `user.expiryDateDue` - 有效期至
- `user.batchImport` - 批量导入
- `user.batchImportUser` - 批量导入用户
- `user.downloadTemplate` - 下载导入模板
- `user.templateDownloadSuccess` - 模板下载成功
- `user.startImport` - 开始导入
- `user.batchImportSuccess` - 批量导入成功
- `user.importFailed` - 导入失败,请检查文件格式
- `user.noFileSelected` - 请选择要导入的文件
- `user.onlyXlsxOrCsv` - 只能上传 .xlsx 或 .csv 格式的文件
- `user.reselect` - 重新选择
- `user.noFileSelectedTip` - 未选择任何文件
- `user.downloadTemplateTip` - 请下载模板,填写用户信息后上传。
### 5. Product 产品管理模块 (13个)
- `product.applicationManagement` - 应用管理
- `product.createApplication` - 创建应用
- `product.applicationName` - 应用名称
- `product.applicationIcon` - 应用图标
- `product.applicationNameRequired` - 请输入应用名称
- `product.associationStatus` - 关联状态
- `product.associated` - 已关联
- `product.notAssociated` - 未关联
- `product.unassociate` - 解除关联
- `product.unassociateSuccess` - 解除关联成功
- `product.unassociateFailed` - 解除关联失败
- `product.viewKey` - 查看KEY
- `product.viewStats` - 查看统计
- `product.disableSuccess` - 停用成功
- `product.enableSuccess` - 启用成功
- `product.operationFailed` - 操作失败
### 6. 其他模块 (47个)
- `count` - 计数: {{count}}
- `increment` - 增加
- `decrement` - 减少
- `reset` - 重置
- `switchLanguage` - 切换语言
- `home.title` - 首页
- `home.welcome` - 欢迎使用我们的带单页路由的 React 应用!
- `home.counterCard` - 计数器演示
- `home.aboutCard` - 关于我们
- `home.workflowCard` - 工作流编辑器
- `home.websocketDemoCard` - WebSocket 演示
- `home.sseDemoCard` - SSE演示
- `workflow.title` - 工作流编辑器
- `workflow.description` - 拖拽节点创建连接,构建您的工作流程。点击节点可进行配置。
- `workflow.addNode` - 添加节点
- `workflow.deleteNode` - 删除选中
- `workflow.saveWorkflow` - 保存工作流
- `workflow.startNode` - 触发节点
- `workflow.conditionNode` - 条件判断
- `workflow.actionNode` - 执行动作
- `workflow.endNode` - 结束节点
- `workflow.newNode` - 新节点
- `workflow.node` - 节点
- `workflow.nodesCreated` - 已创建节点
- `workflow.loadingNodes` - 正在加载节点 {{progress}}%
- `workflow.loadingFailed` - 加载节点失败
- `workflow.create5kNodes` - 创建5000节点
- `workflow.create10kNodes` - 创建10000节点
- `notFound.title` - 页面未找到
- `notFound.description` - 请求的页面不存在。
- `notFound.backToHome` - 返回首页
---
## ✅ 中文缺失的翻译27个
### 1. Common 通用模块 (1个)
- `common.operateSuccess` - Operation successful
### 2. KnowledgeBase 知识库模块 (3个)
- `knowledgeBase.models` - Model
- `knowledgeBase.owner` - Owner
- `knowledgeBase.operation` - Operation
### 3. Application 应用模块 (15个)
- `application.multi_agent` - Cluster
- `application.multi_agentDesc` - Create an Agent Cluster
- `application.current` - Current
- `application.versionName` - Version Name
- `application.versionNameTip` - Version number format: v[major version number].[next version number].[revision number] (e.g. v1.3.0)
- `application.agentName` - Agent Name
- `application.roleType` - Role Type
- `application.coordinator` - Coordinator
- `application.analyzer` - Analyzer
- `application.executor` - Executor
- `application.reviewer` - Reviewer
- `application.updateSubAgent` - Update Sub Agent
- `application.subAgentMaxLength` - Sub Agent maximum {{maxLength}}
- `application.capabilities` - Capabilities
### 4. Space 空间模块 (5个)
- `space.storageType` - Storage Type
- `space.rag` - RAG storage
- `space.ragDesc` - Based on vector retrieval, suitable for document Q&A and semantic search
- `space.neo4j` - Graph storage
- `space.neo4jDesc` - Based on knowledge graph, suitable for relational reasoning and path query
### 5. MemoryExtractionEngine 记忆提取引擎模块 (4个)
- `memoryExtractionEngine.coreEntitiesAfterDedup` - Core entities after deduplication
- `memoryExtractionEngine.extractRelationalTriples` - Extracted relational triples (partial)
- `memoryExtractionEngine.extractRelationalTriplesDesc` - There are a total of {{count}} segments with clear semantic boundaries
- `memoryExtractionEngine.theEffectOfEntityDisambiguationLLMDriven` - The effect of entity disambiguation (LLM driven)
---
## 🎯 建议
### 优先级 1 - 核心功能模块(需要立即补充)
1. **Role 角色管理** - 完整模块缺失15个键
2. **Tenant 租户管理** - 完整模块缺失20个键
3. **Product 产品管理** - 完整模块缺失13个键
4. **User 用户管理扩展** - 批量导入功能缺失13个键
### 优先级 2 - 功能增强(建议补充)
1. **Application 应用模块** - 多代理相关功能15个键
2. **Space 空间模块** - 存储类型配置5个键
3. **MemoryExtractionEngine** - 实体去重相关4个键
### 优先级 3 - 演示/测试功能(可选)
1. **Home/Workflow/NotFound** - 演示页面30个键
2. **通用计数器功能** - 测试功能5个键
---
## 📝 下一步行动
1. **补充英文翻译**: 优先补充 Role、Tenant、Product、User 模块的英文翻译
2. **补充中文翻译**: 补充 Application、Space、MemoryExtractionEngine 模块的中文翻译
3. **清理无用翻译**: 如果 Home/Workflow 等演示功能不再使用,可以考虑从中文文件中移除
4. **建立翻译规范**: 建议建立翻译键的命名规范和审查流程,避免未来出现遗漏
**报告生成完成**

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 13:59:45
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-02-28 16:34:15
* @Last Modified time: 2026-03-03 12:08:42
*/
import { request } from '@/utils/request'
import type { ApplicationModalData } from '@/views/ApplicationManagement/types'
@@ -120,15 +120,19 @@ export const copyApplication = (app_id: string, new_name: string) => {
export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => {
return request.get(`/apps/${app_id}/statistics`, data)
}
// 导出工作流
export const exportWorkflow = (app_id: string, fileName: string) => {
return request.downloadFile(`/apps/${app_id}/workflow/export`, fileName, undefined, undefined, 'GET')
}
// 工作流上传+兼容性分析
// Upload workflow and analyze compatibility
export const importWorkflow = (formData: FormData) => {
return request.uploadFile(`/apps/workflow/import`, formData)
}
// 完成工作流导入
// Complete workflow import
export const completeImportWorkflow = (data: { temp_id: string; name?: string; description?: string }) => {
return request.post(`/apps/workflow/import/save`, data)
}
// Get experience config
export const getExperienceConfig = (share_token: string) => {
return request.get(`/public/share/config`, {}, {
headers: {
'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}`
}
})
}

View File

@@ -154,7 +154,7 @@ export const uploadFile = async (data: FormData, options?: UploadFileOptions) =>
// 下载文件
export const downloadFile = async (fileId: string, fileName?: string) => {
const token = cookieUtils.get('authToken');
const url = `${apiPrefix}/files/${fileId}`;
const url = `/api/files/${fileId}`;
try {
const response = await fetch(url, {

Some files were not shown because too many files have changed in this diff Show More