Merge branch 'release/v0.2.6' into feature/memory_zy
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,6 +29,7 @@ search_results.json
|
||||
api/migrations/versions
|
||||
tmp
|
||||
files
|
||||
powers/
|
||||
|
||||
# Exclude dep files
|
||||
huggingface.co/
|
||||
|
||||
5
api/app/cache/__init__.py
vendored
5
api/app/cache/__init__.py
vendored
@@ -3,9 +3,8 @@ Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
from .memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
6
api/app/cache/memory/__init__.py
vendored
6
api/app/cache/memory/__init__.py
vendored
@@ -3,10 +3,8 @@ Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
from .interest_memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
134
api/app/cache/memory/emotion_memory.py
vendored
@@ -1,134 +0,0 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
136
api/app/cache/memory/implicit_memory.py
vendored
@@ -1,136 +0,0 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -1,9 +1,11 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from celery.schedules import crontab
|
||||
from urllib.parse import quote
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
@@ -43,8 +45,8 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
@@ -82,7 +84,8 @@ celery_app.conf.update(
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -90,11 +93,14 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# 这个30秒的设计不合理
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
implicit_emotions_update_schedule = crontab(
|
||||
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
@@ -115,16 +121,16 @@ beat_schedule_config = {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
"write-all-workspaces-memory": {
|
||||
"task": "app.tasks.write_all_workspaces_memory_task",
|
||||
"schedule": memory_increment_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"update-implicit-emotions-storage": {
|
||||
"task": "app.tasks.update_implicit_emotions_storage",
|
||||
"schedule": implicit_emotions_update_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
#如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
"schedule": memory_increment_schedule,
|
||||
"kwargs": {
|
||||
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
|
||||
},
|
||||
}
|
||||
|
||||
celery_app.conf.beat_schedule = beat_schedule_config
|
||||
|
||||
@@ -396,10 +396,10 @@ async def draft_run(
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
service = AppService(db)
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
# 1. 验证应用
|
||||
app = service._get_app_or_404(app_id)
|
||||
@@ -484,8 +484,8 @@ async def draft_run(
|
||||
}
|
||||
)
|
||||
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
@@ -789,8 +789,8 @@ async def draft_run_compare(
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
@@ -820,8 +820,8 @@ async def draft_run_compare(
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
@@ -835,7 +835,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -441,14 +441,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
|
||||
@@ -208,14 +208,64 @@ async def get_emotion_health(
|
||||
|
||||
|
||||
|
||||
# @router.post("/check-data", response_model=ApiResponse)
|
||||
# async def check_emotion_data_exists(
|
||||
# request: EmotionSuggestionsRequest,
|
||||
# db: Session = Depends(get_db),
|
||||
# current_user: User = Depends(get_current_user),
|
||||
# ):
|
||||
# """检查用户情绪建议数据是否存在
|
||||
|
||||
# Args:
|
||||
# request: 包含 end_user_id
|
||||
# db: 数据库会话
|
||||
# current_user: 当前用户
|
||||
|
||||
# Returns:
|
||||
# 数据存在状态
|
||||
# """
|
||||
# try:
|
||||
# api_logger.info(
|
||||
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
|
||||
# extra={"end_user_id": request.end_user_id}
|
||||
# )
|
||||
|
||||
# # 从数据库获取建议
|
||||
# data = await emotion_service.get_cached_suggestions(
|
||||
# end_user_id=request.end_user_id,
|
||||
# db=db
|
||||
# )
|
||||
|
||||
# if data is None:
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
|
||||
# return fail(
|
||||
# BizCode.NOT_FOUND,
|
||||
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
|
||||
# {"exists": False}
|
||||
# )
|
||||
|
||||
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
|
||||
# return success(data={"exists": True}, msg="情绪建议数据已存在")
|
||||
|
||||
# except Exception as e:
|
||||
# api_logger.error(
|
||||
# f"检查情绪建议数据失败: {str(e)}",
|
||||
# extra={"end_user_id": request.end_user_id},
|
||||
# exc_info=True
|
||||
# )
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
# detail=f"检查情绪建议数据失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
"""获取个性化情绪建议(从数据库读取)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
@@ -223,77 +273,42 @@ async def get_emotion_suggestions(
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
存储的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
# 从数据库获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期,自动触发生成
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
|
||||
f"用户 {request.end_user_id} 的建议数据不存在",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
try:
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
except (ValueError, KeyError) as gen_e:
|
||||
# 预期内的业务异常:配置缺失、数据格式问题等
|
||||
api_logger.warning(
|
||||
f"自动生成建议失败(业务异常): {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
f"自动生成建议失败: {str(gen_e)}",
|
||||
""
|
||||
)
|
||||
except Exception as gen_e:
|
||||
# 非预期异常:记录完整 traceback 便于排查
|
||||
api_logger.error(
|
||||
f"自动生成建议时发生未预期异常: {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成建议时发生内部错误: {str(gen_e)}"
|
||||
)
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
"个性化建议获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
return success(data=data, msg="个性化建议获取成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
@@ -314,7 +329,7 @@ async def generate_emotion_suggestions(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
"""生成个性化情绪建议(调用LLM并保存到数据库)
|
||||
|
||||
Args:
|
||||
request: 包含 end_user_id
|
||||
@@ -342,12 +357,11 @@ async def generate_emotion_suggestions(
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
# 保存到数据库
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -369,4 +383,4 @@ async def generate_emotion_suggestions(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成个性化建议失败: {str(e)}"
|
||||
)
|
||||
)
|
||||
@@ -122,6 +122,48 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def check_user_data_exists(
|
||||
end_user_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
检查用户画像数据是否存在
|
||||
|
||||
Args:
|
||||
end_user_id: 目标用户ID
|
||||
|
||||
Returns:
|
||||
数据存在状态
|
||||
"""
|
||||
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return success(
|
||||
data={"exists": False},
|
||||
msg="画像数据不存在,请点击右上角刷新进行初始化"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
|
||||
return success(data={"exists": True}, msg="画像数据已存在")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
|
||||
|
||||
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
@@ -159,12 +201,8 @@ async def get_preference_tags(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract preferences from cache
|
||||
preferences = cached_profile.get("preferences", [])
|
||||
@@ -230,12 +268,8 @@ async def get_dimension_portrait(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
@@ -278,12 +312,8 @@ async def get_interest_area_distribution(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
@@ -330,12 +360,8 @@ async def get_behavior_habits(
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return fail(BizCode.NOT_FOUND, "", "")
|
||||
|
||||
# Extract habits from cache
|
||||
habits = cached_profile.get("habits", [])
|
||||
|
||||
@@ -90,7 +90,7 @@ async def get_mcp_servers(
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"mFailed to get MCP servers: {str(e)}")
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get MCP servers: {str(e)}"
|
||||
@@ -118,6 +118,65 @@ async def get_mcp_servers(
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/operational_mcp_servers", response_model=ApiResponse)
|
||||
async def get_operational_mcp_servers(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Query the operational mcp servers list in pages
|
||||
- Support keyword search for name,author,owner
|
||||
- Return paging metadata + operational mcp server list
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
|
||||
|
||||
# 1. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
|
||||
mcp_market_config_id=mcp_market_config_id,
|
||||
current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
headers = api.builder_headers(api.headers)
|
||||
|
||||
try:
|
||||
cookies = api.get_cookies(access_token=token, cookies_required=True)
|
||||
r = api.session.get(url, headers=headers, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get operational MCP servers: {str(e)}"
|
||||
)
|
||||
|
||||
data = api._handle_response(r)
|
||||
total = data.get('total_count', 0)
|
||||
mcp_server_list = data.get('mcp_server_list', [])
|
||||
# items = [{
|
||||
# 'name': item.get('name', ''),
|
||||
# 'id': item.get('id', ''),
|
||||
# 'description': item.get('description', '')
|
||||
# } for item in mcp_server_list]
|
||||
|
||||
# 3. Return structured response
|
||||
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
|
||||
|
||||
|
||||
@router.get("/mcp_server", response_model=ApiResponse)
|
||||
async def get_mcp_server(
|
||||
mcp_market_config_id: uuid.UUID,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -633,12 +634,11 @@ async def get_knowledge_type_stats_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | Memory。
|
||||
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- Memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
- 如果用户没有当前工作空间,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
try:
|
||||
@@ -662,34 +662,56 @@ async def get_knowledge_type_stats_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -469,6 +470,8 @@ async def get_chunk_insight(
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
async def dashboard_data(
|
||||
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
|
||||
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
|
||||
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -503,6 +506,15 @@ async def dashboard_data(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
|
||||
|
||||
# 如果没有提供时间范围,默认使用最近30天
|
||||
if start_date is None or end_date is None:
|
||||
from datetime import datetime, timedelta
|
||||
end_dt = datetime.now()
|
||||
start_dt = end_dt - timedelta(days=30)
|
||||
end_date = int(end_dt.timestamp() * 1000)
|
||||
start_date = int(start_dt.timestamp() * 1000)
|
||||
api_logger.info(f"使用默认时间范围: {start_dt} 到 {end_dt}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
@@ -563,17 +575,22 @@ async def dashboard_data(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
|
||||
|
||||
# 3. 获取API调用增量(total_api_call,转换为整数)
|
||||
# 3. 获取API调用统计(total_api_call)
|
||||
try:
|
||||
api_increment = memory_dashboard_service.get_workspace_api_increment(
|
||||
db=db,
|
||||
# 使用 AppStatisticsService 获取真实的API调用统计
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
neo4j_data["total_api_call"] = api_increment
|
||||
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
neo4j_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取API调用增量失败: {str(e)}")
|
||||
api_logger.error(f"获取API调用统计失败: {str(e)}")
|
||||
neo4j_data["total_api_call"] = 0
|
||||
|
||||
result["neo4j_data"] = neo4j_data
|
||||
api_logger.info("成功获取neo4j_data")
|
||||
@@ -589,8 +606,8 @@ async def dashboard_data(
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
@@ -602,10 +619,23 @@ async def dashboard_data(
|
||||
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
|
||||
rag_data["total_knowledge"] = total_kb
|
||||
|
||||
# total_api_call: 固定值
|
||||
rag_data["total_api_call"] = 1024
|
||||
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
|
||||
try:
|
||||
app_stats_service = AppStatisticsService(db)
|
||||
api_stats = app_stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
# 计算总调用次数
|
||||
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
|
||||
rag_data["total_api_call"] = total_api_calls
|
||||
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG模式API调用统计失败,使用默认值: {str(e)}")
|
||||
rag_data["total_api_call"] = 0
|
||||
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
|
||||
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -29,11 +31,11 @@ async def short_term_configs(
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
|
||||
long_term=LongService(end_user_id)
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
@@ -469,7 +469,9 @@ async def create_model_api_key_by_provider(
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
model_config_ids=model_config_ids,
|
||||
capability=api_key_data.capability,
|
||||
is_omni=api_key_data.is_omni
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
|
||||
@@ -124,15 +124,23 @@ def _get_ontology_service(
|
||||
)
|
||||
|
||||
# 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理)
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
if not api_keys:
|
||||
# from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id)
|
||||
if not api_key_config:
|
||||
logger.error(f"Model {llm_id} has no active API key")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="指定的LLM模型没有可用的API密钥"
|
||||
)
|
||||
api_key_config = api_keys[0]
|
||||
# api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id)
|
||||
# if not api_keys:
|
||||
# logger.error(f"Model {llm_id} has no active API key")
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail="指定的LLM模型没有可用的API密钥"
|
||||
# )
|
||||
# api_key_config = api_keys[0]
|
||||
|
||||
is_composite = getattr(model_config, 'is_composite', False)
|
||||
logger.info(
|
||||
@@ -154,6 +162,7 @@ def _get_ontology_service(
|
||||
provider=actual_provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
max_retries=3,
|
||||
timeout=60.0
|
||||
)
|
||||
@@ -514,10 +523,9 @@ async def delete_scene(
|
||||
f"尝试删除系统默认场景: user_id={current_user.id}, "
|
||||
f"scene_id={scene_id}, scene_name={scene.scene_name}"
|
||||
)
|
||||
return fail(
|
||||
BizCode.BAD_REQUEST,
|
||||
"系统默认场景不可删除",
|
||||
"该场景为系统预设场景,不允许删除"
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE"
|
||||
)
|
||||
|
||||
# 创建OntologyService实例
|
||||
@@ -543,6 +551,9 @@ async def delete_scene(
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="场景删除成功")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
@@ -2,25 +2,32 @@ import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
@@ -206,15 +213,13 @@ def list_conversations(
|
||||
logger.debug(f"share_data:{share_data.user_id}")
|
||||
other_id = share_data.user_id
|
||||
service = SharedChatService(db)
|
||||
share, release = service._get_release_by_share_token(share_data.share_token, password)
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
share_token=share_data.share_token,
|
||||
user_id=str(new_end_user.id),
|
||||
@@ -293,19 +298,15 @@ async def chat(
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
|
||||
try:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.app_service import AppService
|
||||
# 验证分享链接和密码
|
||||
share, release = service._get_release_by_share_token(share_token, password)
|
||||
share, release = service.get_release_by_share_token(share_token, password)
|
||||
|
||||
# # Create end_user_id by concatenating app_id with user_id
|
||||
# end_user_id = f"{share.app_id}_{user_id}"
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
@@ -318,7 +319,6 @@ async def chat(
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
@@ -359,12 +359,12 @@ async def chat(
|
||||
app_type = release.app.type if release.app else None
|
||||
|
||||
# 根据应用类型验证配置
|
||||
if app_type == "agent":
|
||||
if app_type == AppType.AGENT:
|
||||
# Agent 类型:验证模型配置
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == "multi_agent":
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
@@ -638,6 +638,34 @@ async def chat(
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@router.get("/config", summary="获取应用启动配置")
|
||||
async def config_query(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
share_service = SharedChatService(db)
|
||||
share_token = share_data.share_token
|
||||
share, release = share_service.get_release_by_share_token(share_token, password)
|
||||
if release.app.type == AppType.WORKFLOW:
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
return success(data=content)
|
||||
|
||||
@@ -249,6 +249,7 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -11,35 +11,37 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class LangChainAgent:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
provider: str = "openai",
|
||||
api_base: Optional[str] = None,
|
||||
is_omni: bool = False,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -60,12 +62,13 @@ class LangChainAgent:
|
||||
self.provider = provider
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.is_omni = is_omni
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
@@ -73,9 +76,9 @@ class LangChainAgent:
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
@@ -89,6 +92,7 @@ class LangChainAgent:
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
is_omni=is_omni,
|
||||
extra_params={
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
@@ -143,21 +147,22 @@ class LangChainAgent:
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
@@ -168,13 +173,13 @@ class LangChainAgent:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
@@ -185,12 +190,12 @@ class LangChainAgent:
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
|
||||
return wrapped_func
|
||||
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
@@ -198,17 +203,17 @@ class LangChainAgent:
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -248,7 +253,7 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
构建多模态消息内容
|
||||
@@ -261,23 +266,26 @@ class LangChainAgent:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
|
||||
# ModelProvider.GPUSTACK] or (
|
||||
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
|
||||
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
# else:
|
||||
# # 通义千问等: {"text": "..."}
|
||||
# content_parts = [{"type": "text", "text": text}]
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
|
||||
return content_parts
|
||||
|
||||
async def chat(
|
||||
@@ -302,7 +310,7 @@ class LangChainAgent:
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat= message
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
@@ -322,8 +330,8 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -367,14 +375,14 @@ class LangChainAgent:
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
@@ -407,12 +415,13 @@ class LangChainAgent:
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -439,16 +448,16 @@ class LangChainAgent:
|
||||
raise
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id:Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -482,7 +491,6 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
@@ -500,13 +508,13 @@ class LangChainAgent:
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
|
||||
# 处理所有可能的流式事件
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
@@ -540,7 +548,7 @@ class LangChainAgent:
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
@@ -577,13 +585,13 @@ class LangChainAgent:
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
logger.debug(f"工具调用开始: {event.get('name')}")
|
||||
elif kind == "on_tool_end":
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
@@ -595,7 +603,8 @@ class LangChainAgent:
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
@@ -609,5 +618,3 @@ class LangChainAgent:
|
||||
logger.info("=" * 80)
|
||||
logger.info("chat_stream 方法执行结束")
|
||||
logger.info("=" * 80)
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -200,14 +201,30 @@ class Settings:
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
|
||||
# implicit_emotions_update: 每天几分执行(分钟,0-59)
|
||||
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
|
||||
# Memory Module Configuration (internal)
|
||||
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
|
||||
@@ -231,7 +248,7 @@ class Settings:
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
@@ -21,7 +21,7 @@ async def get_chunked_dialogs(
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
config_id: Configuration ID for processing (used to load pruning config)
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks
|
||||
@@ -57,6 +57,61 @@ async def get_chunked_dialogs(
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 语义剪枝步骤(在分块之前)
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 加载剪枝配置
|
||||
pruning_config = None
|
||||
if config_id:
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="semantic_pruning"
|
||||
)
|
||||
|
||||
if memory_config:
|
||||
pruning_config = PruningConfig(
|
||||
pruning_switch=memory_config.pruning_enabled,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
# 获取LLM客户端用于剪枝
|
||||
if pruning_config.pruning_switch:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
|
||||
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
|
||||
original_msg_count = len(dialog_data.context.msgs)
|
||||
|
||||
# 使用 prune_dataset 而不是 prune_dialog
|
||||
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog_data])
|
||||
|
||||
if pruned_dialogs:
|
||||
dialog_data = pruned_dialogs[0]
|
||||
remaining_msg_count = len(dialog_data.context.msgs)
|
||||
deleted_count = original_msg_count - remaining_msg_count
|
||||
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
|
||||
else:
|
||||
logger.warning("[剪枝] prune_dataset 返回空列表")
|
||||
else:
|
||||
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
class InterestTags(BaseModel):
|
||||
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
||||
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
return structured_response.meaningful_tags
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM筛选过程中发生错误: {e}")
|
||||
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
||||
"""
|
||||
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
||||
|
||||
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
end_user_id: 用户ID,用于获取LLM配置
|
||||
|
||||
Returns:
|
||||
筛选后的兴趣活动标签列表
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}."
|
||||
)
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
tag_list_str = ", ".join(tags)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
||||
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": rendered_prompt
|
||||
}
|
||||
]
|
||||
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestTags
|
||||
)
|
||||
|
||||
return structured_response.interest_tags
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
return tags
|
||||
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
@@ -139,14 +210,14 @@ async def get_raw_tags_from_db(
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
@@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
@@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool =
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
# 4. 限制返回的标签数量
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取用户的兴趣分布标签。
|
||||
|
||||
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
||||
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 查询更多原始标签,给LLM提供充足上下文
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
||||
|
||||
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
||||
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
||||
|
||||
# 构建最终标签列表:
|
||||
# - 原始标签中存在的,保留原始频率
|
||||
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
||||
final_tags = []
|
||||
seen = set()
|
||||
for tag in interest_tag_names:
|
||||
if tag in seen:
|
||||
continue
|
||||
seen.add(tag)
|
||||
freq = raw_freq_map.get(tag, 1)
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
# 按频率降序排列
|
||||
final_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
@@ -5,20 +5,27 @@
|
||||
- 对话级一次性抽取判定相关性
|
||||
- 仅对"不相关对话"的消息按比例删除
|
||||
- 重要信息(时间、编号、金额、联系方式、地址等)优先保留
|
||||
- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Tuple, Set
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||
SceneConfigRegistry,
|
||||
ScenePatterns
|
||||
)
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
@@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel):
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
"""消息重要性批量判断的结构化返回(用于LLM语义判断)。
|
||||
|
||||
- importance_scores: 消息索引到重要性分数的映射 (0-10分)
|
||||
- reasons: 可选的判断理由
|
||||
"""
|
||||
importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射")
|
||||
reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由")
|
||||
|
||||
|
||||
class QAPair(BaseModel):
|
||||
"""问答对模型,用于识别和保护对话中的问答结构。"""
|
||||
question_idx: int = Field(..., description="问题消息的索引")
|
||||
answer_idx: int = Field(..., description="答案消息的索引")
|
||||
confidence: float = Field(default=1.0, description="问答对的置信度(0-1)")
|
||||
|
||||
|
||||
class SemanticPruner:
|
||||
"""语义剪枝:在预处理与分块之间过滤与场景不相关内容。
|
||||
|
||||
@@ -43,109 +67,374 @@ class SemanticPruner:
|
||||
重要信息(时间、编号、金额、联系方式、地址等)优先保留。
|
||||
"""
|
||||
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None):
|
||||
cfg_dict = get_pruning_config() if config is None else config.model_dump()
|
||||
self.config = PruningConfig.model_validate(cfg_dict)
|
||||
def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5):
|
||||
# 如果没有提供config,使用默认配置
|
||||
if config is None:
|
||||
# 使用默认的剪枝配置
|
||||
config = PruningConfig(
|
||||
pruning_switch=False, # 默认关闭剪枝,保持向后兼容
|
||||
pruning_scene="education",
|
||||
pruning_threshold=0.5
|
||||
)
|
||||
|
||||
self.config = config
|
||||
self.llm_client = llm_client
|
||||
self.language = language # 保存语言配置
|
||||
self.max_concurrent = max_concurrent # 新增:最大并发数
|
||||
|
||||
# 详细日志配置:限制逐条消息日志的数量
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
|
||||
# 检查场景是否有专门支持
|
||||
is_supported = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
if is_supported:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用专门配置")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 未预定义,使用通用配置(保守策略)")
|
||||
self._log(f"[剪枝-初始化] 支持的场景: {SceneConfigRegistry.get_all_scenes()}")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
# 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染
|
||||
self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {}
|
||||
|
||||
# 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存
|
||||
self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict()
|
||||
self._cache_max_size = 1000 # 缓存大小限制
|
||||
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
# 采用顺序处理,移除并发配置以简化与稳定执行
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
- 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。
|
||||
- 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。
|
||||
- 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
patterns = [
|
||||
r"\b\d{4}-\d{1,2}-\d{1,2}\b",
|
||||
r"\b\d{1,2}:\d{2}\b",
|
||||
r"\d{4}年\d{1,2}月\d{1,2}日",
|
||||
r"上午|下午|AM|PM",
|
||||
r"订单号|工单|申请号|编号|ID|账号|账户",
|
||||
r"电话|手机号|微信|QQ|邮箱",
|
||||
r"地址|地点",
|
||||
r"金额|费用|价格|¥|¥|\d+元",
|
||||
r"时间|日期|有效期|截止",
|
||||
]
|
||||
for p in patterns:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
简单启发:匹配到的类别越多、越关键分值越高。
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
import re
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
weights = [
|
||||
(r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3),
|
||||
(r"\b\d{1,2}:\d{2}\b", 2),
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 4),
|
||||
(r"电话|手机号|微信|QQ|邮箱", 3),
|
||||
(r"地址|地点", 2),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 4),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
]
|
||||
for p, w in weights:
|
||||
if re.search(p, text, flags=re.IGNORECASE):
|
||||
score += w
|
||||
return score
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体;
|
||||
- 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
"""
|
||||
import re
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
# 常见填充语
|
||||
fillers = [
|
||||
"你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢",
|
||||
"拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??"
|
||||
]
|
||||
if t in fillers:
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
# 长度与字符类型判断
|
||||
if len(t) <= 8:
|
||||
# 非数字、无关键实体的短文本
|
||||
if not re.search(r"[0-9]", t) and not self._is_important_message(message):
|
||||
# 主要是标点或简单确认词
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers:
|
||||
return True
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
"你好", "您好", "hello", "hi",
|
||||
"拜拜", "再见", "拜", "88", "bye",
|
||||
"好的", "好", "行", "可以", "嗯", "哦", "啊",
|
||||
"是的", "对", "对的", "没错", "是啊",
|
||||
"哈哈", "呵呵", "嘿嘿", "嗯嗯"
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _batch_evaluate_importance_with_llm(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
context: str = ""
|
||||
) -> Dict[int, int]:
|
||||
"""使用LLM批量评估消息的重要性(语义层面)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
context: 对话上下文(可选)
|
||||
|
||||
Returns:
|
||||
消息索引到重要性分数(0-10)的映射
|
||||
"""
|
||||
if not self.llm_client or not messages:
|
||||
return {}
|
||||
|
||||
# 构建批量评估的提示词
|
||||
msg_list = []
|
||||
for idx, msg in enumerate(messages):
|
||||
msg_list.append(f"{idx}. {msg.msg}")
|
||||
|
||||
msg_text = "\n".join(msg_list)
|
||||
|
||||
prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分):
|
||||
- 0-2分:无意义的寒暄、口头禅、纯表情
|
||||
- 3-5分:一般性对话,有一定信息量但不关键
|
||||
- 6-8分:包含重要信息(时间、地点、人物、事件等)
|
||||
- 9-10分:关键决策、承诺、重要数据
|
||||
|
||||
对话上下文:
|
||||
{context if context else "无"}
|
||||
|
||||
待评估的消息:
|
||||
{msg_text}
|
||||
|
||||
请以JSON格式返回,格式为:
|
||||
{{
|
||||
"importance_scores": {{
|
||||
"0": 分数,
|
||||
"1": 分数,
|
||||
...
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
messages_for_llm = [
|
||||
{"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages_for_llm,
|
||||
MessageImportanceResponse
|
||||
)
|
||||
|
||||
# 转换字符串键为整数键
|
||||
return {int(k): v for k, v in response.importance_scores.items()}
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}")
|
||||
return {}
|
||||
|
||||
def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]:
|
||||
"""识别对话中的问答对,用于保护问答结构的完整性。
|
||||
|
||||
改进版:使用场景特定的问句关键词,并排除寒暄类问句
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
问答对列表
|
||||
"""
|
||||
qa_pairs = []
|
||||
|
||||
# 寒暄类问句,不应该被保护(这些不是真正的问答)
|
||||
greeting_questions = {
|
||||
"在吗", "在不在", "你好吗", "怎么样", "好吗",
|
||||
"有空吗", "忙吗", "睡了吗", "起床了吗"
|
||||
}
|
||||
|
||||
for i in range(len(messages) - 1):
|
||||
current_msg = messages[i].msg.strip()
|
||||
next_msg = messages[i + 1].msg.strip()
|
||||
|
||||
# 排除寒暄类问句
|
||||
if current_msg in greeting_questions:
|
||||
continue
|
||||
|
||||
# 使用场景特定的问句关键词,但要求更严格
|
||||
is_question = False
|
||||
|
||||
# 1. 以问号结尾
|
||||
if current_msg.endswith("?") or current_msg.endswith("?"):
|
||||
is_question = True
|
||||
# 2. 包含实质性问句关键词(排除"吗"这种太宽泛的)
|
||||
elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]):
|
||||
is_question = True
|
||||
|
||||
if is_question and next_msg:
|
||||
# 检查下一条消息是否像答案(不是另一个问句,也不是寒暄)
|
||||
is_answer = not (next_msg.endswith("?") or next_msg.endswith("?"))
|
||||
|
||||
# 排除寒暄类回复
|
||||
greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"}
|
||||
if next_msg in greeting_answers:
|
||||
is_answer = False
|
||||
|
||||
if is_answer:
|
||||
qa_pairs.append(QAPair(
|
||||
question_idx=i,
|
||||
answer_idx=i + 1,
|
||||
confidence=0.8 # 基于规则的置信度
|
||||
))
|
||||
|
||||
return qa_pairs
|
||||
|
||||
def _get_protected_indices(
|
||||
self,
|
||||
messages: List[ConversationMessage],
|
||||
qa_pairs: List[QAPair],
|
||||
window_size: int = 2
|
||||
) -> Set[int]:
|
||||
"""获取需要保护的消息索引集合(问答对+上下文窗口)。
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
qa_pairs: 问答对列表
|
||||
window_size: 上下文窗口大小(前后各保留几条消息)
|
||||
|
||||
Returns:
|
||||
需要保护的消息索引集合
|
||||
"""
|
||||
protected = set()
|
||||
|
||||
for qa_pair in qa_pairs:
|
||||
# 保护问答对本身
|
||||
protected.add(qa_pair.question_idx)
|
||||
protected.add(qa_pair.answer_idx)
|
||||
|
||||
# 保护上下文窗口
|
||||
for offset in range(-window_size, window_size + 1):
|
||||
q_idx = qa_pair.question_idx + offset
|
||||
a_idx = qa_pair.answer_idx + offset
|
||||
|
||||
if 0 <= q_idx < len(messages):
|
||||
protected.add(q_idx)
|
||||
if 0 <= a_idx < len(messages):
|
||||
protected.add(a_idx)
|
||||
|
||||
return protected
|
||||
|
||||
async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse:
|
||||
"""对话级一次性抽取:从整段对话中提取重要信息并判定相关性。
|
||||
|
||||
- 仅使用 LLM 结构化输出;
|
||||
改进版:
|
||||
- LRU缓存管理
|
||||
- 重试机制
|
||||
- 降级策略
|
||||
"""
|
||||
# 缓存命中则直接返回(场景+内容作为键)
|
||||
cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest()
|
||||
|
||||
# LRU缓存:如果命中,移到末尾(最近使用)
|
||||
if cache_key in self._dialog_extract_cache:
|
||||
self._dialog_extract_cache.move_to_end(cache_key)
|
||||
return self._dialog_extract_cache[cache_key]
|
||||
|
||||
rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene})
|
||||
# LRU缓存大小限制:超过限制时删除最旧的条目
|
||||
if len(self._dialog_extract_cache) >= self._cache_max_size:
|
||||
# 删除最旧的条目(OrderedDict的第一个)
|
||||
oldest_key = next(iter(self._dialog_extract_cache))
|
||||
del self._dialog_extract_cache[oldest_key]
|
||||
self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目")
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
|
||||
# 强制使用 LLM;移除正则回退
|
||||
# 强制使用 LLM
|
||||
if not self.llm_client:
|
||||
raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。")
|
||||
|
||||
@@ -153,12 +442,32 @@ class SemanticPruner:
|
||||
{"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"},
|
||||
{"role": "user", "content": rendered},
|
||||
]
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e
|
||||
|
||||
# 重试机制
|
||||
max_retries = 3
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
ex = await self.llm_client.response_structured(messages, DialogExtractionResponse)
|
||||
self._dialog_extract_cache[cache_key] = ex
|
||||
return ex
|
||||
except Exception as e:
|
||||
if attempt < max_retries - 1:
|
||||
self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}")
|
||||
await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避
|
||||
continue
|
||||
else:
|
||||
# 降级策略:标记为相关,避免误删
|
||||
self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)")
|
||||
fallback_response = DialogExtractionResponse(
|
||||
is_related=True,
|
||||
times=[],
|
||||
ids=[],
|
||||
amounts=[],
|
||||
contacts=[],
|
||||
addresses=[],
|
||||
keywords=[]
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
@@ -248,12 +557,14 @@ class SemanticPruner:
|
||||
async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]:
|
||||
"""数据集层面:全局消息级剪枝,保留所有对话。
|
||||
|
||||
- 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。
|
||||
- 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话。
|
||||
改进版:
|
||||
- 消息级独立判断,每条消息根据场景规则独立评估
|
||||
- 问答对保护已注释(暂不启用,留作观察)
|
||||
- 优化删除策略:填充消息 → 不重要消息 → 低分重要消息
|
||||
- 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留
|
||||
- 保证每段对话至少保留1条消息,不会删除整段对话
|
||||
"""
|
||||
# 如果剪枝功能关闭,直接返回原始数据集。
|
||||
# 如果剪枝功能关闭,直接返回原始数据集
|
||||
if not self.config.pruning_switch:
|
||||
return dialogs
|
||||
|
||||
@@ -264,179 +575,140 @@ class SemanticPruner:
|
||||
proportion = 0.9
|
||||
if proportion < 0.0:
|
||||
proportion = 0.0
|
||||
evaluated_dialogs = [] # list of dicts: {dialog, is_related}
|
||||
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}"
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||
)
|
||||
# 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存)
|
||||
evaluated_dialogs = []
|
||||
for idx, dd in enumerate(dialogs):
|
||||
try:
|
||||
ex = await self._extract_dialog_important(dd.content)
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": bool(ex.is_related),
|
||||
"index": idx,
|
||||
"extraction": ex
|
||||
})
|
||||
except Exception:
|
||||
evaluated_dialogs.append({
|
||||
"dialog": dd,
|
||||
"is_related": True,
|
||||
"index": idx,
|
||||
"extraction": None
|
||||
})
|
||||
|
||||
# 统计相关 / 不相关对话
|
||||
not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]]
|
||||
related_dialogs = [d for d in evaluated_dialogs if d["is_related"]]
|
||||
self._log(
|
||||
f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}"
|
||||
)
|
||||
|
||||
# 简洁打印第几段对话相关/不相关(索引基于1)
|
||||
def _fmt_indices(items, cap: int = 10):
|
||||
inds = [i["index"] + 1 for i in items]
|
||||
if len(inds) <= cap:
|
||||
return inds
|
||||
# 超过上限时只打印前cap个,并标注总数
|
||||
return inds[:cap] + ["...", f"共{len(inds)}个"]
|
||||
|
||||
rel_inds = _fmt_indices(related_dialogs)
|
||||
nrel_inds = _fmt_indices(not_related_dialogs)
|
||||
self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段")
|
||||
|
||||
|
||||
result: List[DialogData] = []
|
||||
if not_related_dialogs:
|
||||
# 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM)
|
||||
per_dialog_info = {}
|
||||
total_unrelated = 0
|
||||
total_capacity = 0
|
||||
for d in not_related_dialogs:
|
||||
dd = d["dialog"]
|
||||
extraction = d.get("extraction")
|
||||
if extraction is None:
|
||||
extraction = await self._extract_dialog_important(dd.content)
|
||||
# 合并所有重要标记
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
msgs = dd.context.msgs
|
||||
# 分类消息
|
||||
imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)]
|
||||
unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs]
|
||||
# 重要消息按重要性排序
|
||||
imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))]
|
||||
info = {
|
||||
"dialog": dd,
|
||||
"total_msgs": len(msgs),
|
||||
"unrelated_count": len(msgs),
|
||||
"imp_ids_sorted": imp_sorted_ids,
|
||||
"unimp_ids": [id(m) for m in unimp_unrel_msgs],
|
||||
}
|
||||
per_dialog_info[d["index"]] = info
|
||||
total_unrelated += info["unrelated_count"]
|
||||
# 全局删除配额:比例作用于全部不相关消息(重要+不重要)
|
||||
global_delete = int(total_unrelated * proportion)
|
||||
if proportion > 0 and total_unrelated > 0 and global_delete == 0:
|
||||
global_delete = 1
|
||||
# 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息
|
||||
capacities = []
|
||||
for d in not_related_dialogs:
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
# 统计重要数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
unimp_count = len(info["unimp_ids"])
|
||||
imp_cap = int(imp_count * proportion)
|
||||
cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1))
|
||||
capacities.append(cap)
|
||||
total_capacity = sum(capacities)
|
||||
if global_delete > total_capacity:
|
||||
print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。")
|
||||
global_delete = total_capacity
|
||||
|
||||
# 配额分配:按不相关消息占比分配到各对话,但不超过各自容量
|
||||
alloc = []
|
||||
for i, d in enumerate(not_related_dialogs):
|
||||
idx = d["index"]
|
||||
info = per_dialog_info[idx]
|
||||
share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0
|
||||
alloc.append(min(share, capacities[i]))
|
||||
allocated = sum(alloc)
|
||||
rem = global_delete - allocated
|
||||
turn = 0
|
||||
while rem > 0 and turn < 100000:
|
||||
progressed = False
|
||||
for i in range(len(not_related_dialogs)):
|
||||
if rem <= 0:
|
||||
break
|
||||
if alloc[i] < capacities[i]:
|
||||
alloc[i] += 1
|
||||
rem -= 1
|
||||
progressed = True
|
||||
if not progressed:
|
||||
break
|
||||
turn += 1
|
||||
|
||||
# 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先)
|
||||
total_deleted_confirm = 0
|
||||
for d in evaluated_dialogs:
|
||||
dd = d["dialog"]
|
||||
msgs = dd.context.msgs
|
||||
original = len(msgs)
|
||||
if d["is_related"]:
|
||||
result.append(dd)
|
||||
continue
|
||||
idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None)
|
||||
if idx_in_unrel is None:
|
||||
result.append(dd)
|
||||
continue
|
||||
quota = alloc[idx_in_unrel]
|
||||
info = per_dialog_info[d["index"]]
|
||||
# 计算本对话重要最多可删数量
|
||||
imp_count = len(info["imp_ids_sorted"])
|
||||
imp_del_cap = int(imp_count * proportion)
|
||||
# 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条)
|
||||
unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))])
|
||||
del_unimp = min(quota, len(unimp_delete_ids))
|
||||
rem_quota = quota - del_unimp
|
||||
# 再从重要里选低分优先的删除ID(不超过 imp_del_cap)
|
||||
imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)])
|
||||
deleted_here = 0
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept = []
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids):
|
||||
actual_imp_deleted += 1
|
||||
deleted_here += 1
|
||||
continue
|
||||
kept.append(m)
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
dd.context.msgs = kept
|
||||
total_deleted_confirm += deleted_here
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}"
|
||||
)
|
||||
result.append(dd)
|
||||
self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。")
|
||||
else:
|
||||
# 全部相关:不执行剪枝
|
||||
result = [d["dialog"] for d in evaluated_dialogs]
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
if proportion > 0 and original_count > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
|
||||
# 确保至少保留1条消息
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
for idx, m in enumerate(msgs):
|
||||
if idx not in to_delete_indices:
|
||||
kept_msgs.append(m)
|
||||
|
||||
# 确保至少保留1条
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
dd.context.msgs = kept_msgs
|
||||
deleted_count = original_count - len(kept_msgs)
|
||||
total_deleted_msgs += deleted_count
|
||||
|
||||
# 输出删除详情
|
||||
if deleted_details:
|
||||
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
|
||||
for detail in deleted_details:
|
||||
self._log(f" {detail}")
|
||||
|
||||
# ========== 问答对统计(已注释) ==========
|
||||
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
|
||||
# ========================================
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
result.append(dd)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成)
|
||||
# 保存日志
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
# 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
@@ -448,6 +720,7 @@ class SemanticPruner:
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
"拜拜", "再见", "88", "拜", "回见",
|
||||
# 口头禅
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
"[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]",
|
||||
"[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]",
|
||||
"[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]",
|
||||
"[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]",
|
||||
# 网络用语
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
@@ -1932,17 +1932,17 @@ def preprocess_data(
|
||||
Returns:
|
||||
经过清洗转换后的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 数据预处理 ===")
|
||||
logger.debug("=== 数据预处理 ===")
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import (
|
||||
DataPreprocessor,
|
||||
)
|
||||
preprocessor = DataPreprocessor()
|
||||
try:
|
||||
cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices)
|
||||
print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据")
|
||||
return cleaned_data
|
||||
except Exception as e:
|
||||
print(f"数据预处理过程中出现错误: {e}")
|
||||
logger.error(f"数据预处理过程中出现错误: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===")
|
||||
if not data:
|
||||
raise ValueError("预处理数据为空,无法进行分块")
|
||||
|
||||
@@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: Optional[str] = None,
|
||||
llm_client: Optional[Any] = None,
|
||||
skip_cleaning: bool = True,
|
||||
pruning_config: Optional[Dict] = None,
|
||||
) -> List[DialogData]:
|
||||
"""包含数据预处理步骤的完整分块流程
|
||||
|
||||
@@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
input_data_path: 输入数据路径
|
||||
llm_client: LLM 客户端
|
||||
skip_cleaning: 是否跳过数据清洗步骤(默认False)
|
||||
pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold
|
||||
|
||||
Returns:
|
||||
带 chunks 的 DialogData 列表
|
||||
"""
|
||||
print("\n=== 完整数据处理流程(包含预处理)===")
|
||||
logger.debug("=== 完整数据处理流程(包含预处理)===")
|
||||
|
||||
if input_data_path is None:
|
||||
input_data_path = os.path.join(
|
||||
@@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
pruner = SemanticPruner(llm_client=llm_client)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
|
||||
# 构建剪枝配置
|
||||
if pruning_config:
|
||||
# 使用传入的配置
|
||||
config = PruningConfig(**pruning_config)
|
||||
logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
else:
|
||||
# 使用默认配置(关闭剪枝)
|
||||
config = None
|
||||
logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)")
|
||||
|
||||
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||
|
||||
# 记录单对话场景下剪枝前的消息数量
|
||||
single_dialog_original_msgs = None
|
||||
@@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None:
|
||||
remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0
|
||||
deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs)
|
||||
print(
|
||||
logger.debug(
|
||||
f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs},"
|
||||
f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。"
|
||||
)
|
||||
else:
|
||||
print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话")
|
||||
|
||||
# 保存剪枝后的数据
|
||||
try:
|
||||
@@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
dp = DataPreprocessor(output_file_path=pruned_output_path)
|
||||
dp.save_data(preprocessed_data, output_path=pruned_output_path)
|
||||
except Exception as se:
|
||||
print(f"保存剪枝结果失败:{se}")
|
||||
logger.error(f"保存剪枝结果失败:{se}")
|
||||
except Exception as e:
|
||||
print(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}")
|
||||
|
||||
# 步骤3: 对话分块
|
||||
return await get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Optional, List, Any
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.models.message_models import DialogData, Chunk
|
||||
@@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class ChunkerStrategy(Enum):
|
||||
"""Supported chunking strategies."""
|
||||
RECURSIVE = "RecursiveChunker"
|
||||
SEMANTIC = "SemanticChunker"
|
||||
LATE = "LateChunker"
|
||||
NEURAL = "NeuralChunker"
|
||||
LLM = "LLMChunker"
|
||||
|
||||
@classmethod
|
||||
def get_valid_strategies(cls) -> List[str]:
|
||||
"""Get list of valid strategy names."""
|
||||
return [strategy.value for strategy in cls]
|
||||
|
||||
|
||||
class DialogueChunker:
|
||||
"""A class that processes dialogues and fills them with chunks based on a specified strategy.
|
||||
|
||||
@@ -17,23 +33,51 @@ class DialogueChunker:
|
||||
of different chunking strategies to dialogue data.
|
||||
"""
|
||||
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None):
|
||||
def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None):
|
||||
"""Initialize the DialogueChunker with a specific chunking strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker
|
||||
llm_client: LLM client instance (required for LLMChunker strategy)
|
||||
|
||||
Raises:
|
||||
ValueError: If chunker_strategy is invalid or required parameters are missing
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# Validate strategy
|
||||
valid_strategies = ChunkerStrategy.get_valid_strategies()
|
||||
if chunker_strategy not in valid_strategies:
|
||||
raise ValueError(
|
||||
f"Invalid chunker_strategy: '{chunker_strategy}'. "
|
||||
f"Must be one of {valid_strategies}"
|
||||
)
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
self.chunker_strategy = chunker_strategy
|
||||
logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}")
|
||||
|
||||
try:
|
||||
# Load and validate configuration
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
if not chunker_config_dict:
|
||||
raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}")
|
||||
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
|
||||
# Initialize chunker client
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
if not llm_client:
|
||||
raise ValueError("llm_client is required for LLMChunker strategy")
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
self.chunker_client = ChunkerClient(self.chunker_config)
|
||||
|
||||
logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]:
|
||||
async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]:
|
||||
"""Process a dialogue by generating chunks and adding them to the DialogData object.
|
||||
|
||||
Args:
|
||||
@@ -43,54 +87,125 @@ class DialogueChunker:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
ValueError: If dialogue is invalid or chunking fails
|
||||
Exception: If chunking process encounters an error
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# Validate input
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
f"Dialogue {dialogue.ref_id} has no messages to chunk. "
|
||||
f"Context: {dialogue.context is not None}, "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages "
|
||||
f"using strategy: {self.chunker_strategy}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate chunks
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
return chunks
|
||||
# Validate results
|
||||
if not chunks or len(chunks) == 0:
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs)}, "
|
||||
f"Content length: {len(dialogue.content) if dialogue.content else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str:
|
||||
logger.info(
|
||||
f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. "
|
||||
f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
except ValueError:
|
||||
# Re-raise validation errors
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def save_chunking_results(
|
||||
self,
|
||||
chunks: List[Chunk],
|
||||
dialogue: DialogData,
|
||||
output_path: Optional[str] = None,
|
||||
preview_length: int = 100
|
||||
) -> str:
|
||||
"""Save the chunking results to a file and return the output path.
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output
|
||||
chunks: List of Chunk objects to save
|
||||
dialogue: The DialogData object that was processed
|
||||
output_path: Optional path to save the output (defaults to current directory)
|
||||
preview_length: Maximum length of content preview (default: 100)
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
|
||||
Raises:
|
||||
ValueError: If chunks or dialogue is invalid
|
||||
IOError: If file writing fails
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
# Validate input
|
||||
if not chunks:
|
||||
raise ValueError("chunks list cannot be empty")
|
||||
if not dialogue:
|
||||
raise ValueError("dialogue cannot be None")
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
# Generate default output path if not provided
|
||||
if not output_path:
|
||||
output_dir = Path(__file__).parent.parent.parent
|
||||
output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
|
||||
logger.info(f"Saving chunking results to: {output_path}")
|
||||
|
||||
try:
|
||||
# Prepare output content
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages",
|
||||
f"Total characters: {len(dialogue.content) if dialogue.content else 0}",
|
||||
f"Generated {len(chunks)} chunks:",
|
||||
""
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(chunks, 1):
|
||||
content_preview = chunk.content[:preview_length] if chunk.content else ""
|
||||
if len(chunk.content) > preview_length:
|
||||
content_preview += "..."
|
||||
|
||||
output_lines.append(f" Chunk {i}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {content_preview}")
|
||||
if chunk.metadata:
|
||||
output_lines.append(f" Metadata: {chunk.metadata}")
|
||||
output_lines.append("")
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
# Write to file
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(output_lines))
|
||||
|
||||
logger.info(f"Chunking results saved to: {output_path}")
|
||||
return output_path
|
||||
logger.info(f"Successfully saved chunking results to: {output_path}")
|
||||
return output_path
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -327,7 +327,7 @@ class MultiOntologyParser:
|
||||
|
||||
Example:
|
||||
>>> parser = MultiOntologyParser([
|
||||
... "General_purpose_entity.ttl",
|
||||
... "app/core/memory/ontology_services/General_purpose_entity.ttl",
|
||||
... "domain_specific.owl"
|
||||
... ])
|
||||
>>> registry = parser.parse_all()
|
||||
|
||||
@@ -400,7 +400,8 @@ async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str,
|
||||
language: str = "zh"
|
||||
language: str = "zh",
|
||||
user_display_name: str = None
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
@@ -410,16 +411,22 @@ async def render_user_summary_prompt(
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user")
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
# 如果没有提供 user_display_name,使用默认值
|
||||
if user_display_name is None:
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
|
||||
template = prompt_env.get_template("user_summary.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements,
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -429,7 +436,8 @@ async def render_user_summary_prompt(
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements),
|
||||
'language': language
|
||||
'language': language,
|
||||
'user_display_name': user_display_name
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the interest filter prompt using the interest_filter.jinja2 template.
|
||||
|
||||
Args:
|
||||
tag_list: Comma-separated string of raw tags to filter
|
||||
language: Output language ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("interest_filter.jinja2")
|
||||
rendered_prompt = template.render(tag_list=tag_list, language=language)
|
||||
log_prompt_rendering('interest filter', rendered_prompt)
|
||||
return rendered_prompt
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{% if language == "zh" %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
|
||||
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
|
||||
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
|
||||
- '川味可视化', '川菜' → '烹饪'
|
||||
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
|
||||
- '吉他', '指弹', '琴谱' → '吉他'
|
||||
- '跑步', '5公里', '跑鞋' → '跑步'
|
||||
- '瑜伽垫', '瑜伽课' → '瑜伽'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
|
||||
- If multiple tags all point to '攀岩', output '攀岩' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
|
||||
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in Chinese.
|
||||
|
||||
**Example**:
|
||||
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
|
||||
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
|
||||
{% else %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
|
||||
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
|
||||
- 'morning meditation streak', 'mind-body peak' → 'meditation'
|
||||
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
|
||||
- 'open source project', 'data visualization tool', 'Python' → 'programming'
|
||||
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
|
||||
- 'running', '5km', 'running shoes' → 'running'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
|
||||
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in English.
|
||||
|
||||
**Example**:
|
||||
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
|
||||
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
|
||||
{% endif %}
|
||||
@@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
- User ID: {{ user_id }}
|
||||
{% if user_display_name %}
|
||||
- User Display Name: {{ user_display_name }}
|
||||
{% endif %}
|
||||
{% if entities %}
|
||||
- Core Entities & Frequency: {{ entities }}
|
||||
@@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
3. Avoid excessive adjectives and empty phrases
|
||||
4. Strictly follow the output format specified below
|
||||
|
||||
{% if language == "zh" %}
|
||||
**【严格人称规定】**
|
||||
- 在描述用户时,必须使用"{{ user_display_name }}"作为人称
|
||||
- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户
|
||||
- 绝对禁止在摘要中出现任何形式的UUID或ID字符串
|
||||
- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA)
|
||||
{% else %}
|
||||
**【STRICT PRONOUN RULES】**
|
||||
- When describing the user, you MUST use "{{ user_display_name }}" as the reference
|
||||
- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user
|
||||
- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary
|
||||
- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they)
|
||||
{% endif %}
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
{% if language == "zh" %}
|
||||
@@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: 张三
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
- Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则
|
||||
|
||||
Example Output:
|
||||
【基本介绍】
|
||||
我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。
|
||||
|
||||
【性格特点】
|
||||
性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。
|
||||
@@ -121,13 +135,13 @@ Example Output:
|
||||
"让每一个产品决策都充满温度。"
|
||||
{% else %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- User Display Name: John
|
||||
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
|
||||
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
|
||||
|
||||
Example Output:
|
||||
【Basic Introduction】
|
||||
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
|
||||
【Personality Traits】
|
||||
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.
|
||||
|
||||
@@ -21,31 +21,55 @@ from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class RedBearModelConfig(BaseModel):
|
||||
"""模型配置基类"""
|
||||
model_name: str
|
||||
provider: str
|
||||
api_key: str
|
||||
base_url: Optional[str] = None
|
||||
is_omni: bool = False # 是否为 Omni 模型
|
||||
# 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置
|
||||
timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0")))
|
||||
# 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置
|
||||
max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2")))
|
||||
concurrency: int = 5 # 并发限流
|
||||
concurrency: int = 5 # 并发限流
|
||||
extra_params: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class RedBearModelFactory:
|
||||
"""模型工厂类"""
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
|
||||
|
||||
# 打印供应商信息用于调试
|
||||
from app.core.logging_config import get_business_logger
|
||||
logger = get_business_logger()
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}")
|
||||
logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}")
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
import httpx
|
||||
if not config.base_url:
|
||||
config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
timeout_config = httpx.Timeout(
|
||||
timeout=config.timeout,
|
||||
connect=60.0,
|
||||
read=config.timeout,
|
||||
write=60.0,
|
||||
pool=10.0,
|
||||
)
|
||||
return {
|
||||
"model": config.model_name,
|
||||
"base_url": config.base_url,
|
||||
"api_key": config.api_key,
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
|
||||
# 使用 httpx.Timeout 对象来设置详细的超时配置
|
||||
@@ -65,7 +89,7 @@ class RedBearModelFactory:
|
||||
"timeout": timeout_config,
|
||||
"max_retries": config.max_retries,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
# DashScope (通义千问) 使用自己的参数格式
|
||||
# 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数
|
||||
@@ -82,7 +106,7 @@ class RedBearModelFactory:
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
from botocore.config import Config as BotoConfig
|
||||
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||
|
||||
|
||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||
# Configure with increased connection pool
|
||||
@@ -90,16 +114,16 @@ class RedBearModelFactory:
|
||||
max_pool_connections=max_pool_connections,
|
||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||
)
|
||||
|
||||
|
||||
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||
model_id = normalize_bedrock_model_id(config.model_name)
|
||||
|
||||
|
||||
params = {
|
||||
"model_id": model_id,
|
||||
"config": boto_config,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
|
||||
# 解析 API key (格式: access_key_id:secret_access_key)
|
||||
if config.api_key and ":" in config.api_key:
|
||||
access_key_id, secret_access_key = config.api_key.split(":", 1)
|
||||
@@ -107,45 +131,52 @@ class RedBearModelFactory:
|
||||
params["aws_secret_access_key"] = secret_access_key
|
||||
elif config.api_key:
|
||||
params["aws_access_key_id"] = config.api_key
|
||||
|
||||
|
||||
# 设置 region
|
||||
if config.base_url:
|
||||
params["region_name"] = config.base_url
|
||||
elif "region_name" not in params:
|
||||
params["region_name"] = "us-east-1" # 默认区域
|
||||
|
||||
|
||||
return params
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]:
|
||||
"""根据提供商获取模型参数"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
return {
|
||||
return {
|
||||
"model": config.model_name,
|
||||
# "base_url": config.base_url,
|
||||
"jina_api_key": config.api_key,
|
||||
**config.extra_params
|
||||
}
|
||||
}
|
||||
else:
|
||||
raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]:
|
||||
|
||||
def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = config.provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
@@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
return OpenAIEmbeddings
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
return DashScopeEmbeddings
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
return OllamaEmbeddings
|
||||
@@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]:
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
|
||||
def get_provider_rerank_class(provider: str):
|
||||
"""根据模型提供商获取对应的模型类"""
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
provider = provider.lower()
|
||||
if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
return JinaRerank
|
||||
# elif provider == ModelProvider.OLLAMA:
|
||||
# from langchain_ollama import OllamaEmbeddings
|
||||
# return OllamaEmbeddings
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -6,6 +6,8 @@ models:
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: bedrock
|
||||
@@ -15,6 +17,9 @@ models:
|
||||
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -28,6 +33,9 @@ models:
|
||||
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -42,6 +50,8 @@ models:
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -54,6 +64,9 @@ models:
|
||||
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -67,6 +80,8 @@ models:
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -78,6 +93,8 @@ models:
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -89,6 +106,8 @@ models:
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -101,6 +120,8 @@ models:
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -113,6 +134,8 @@ models:
|
||||
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
@@ -122,6 +145,8 @@ models:
|
||||
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
@@ -131,6 +156,9 @@ models:
|
||||
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
- vision
|
||||
@@ -141,6 +169,8 @@ models:
|
||||
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -150,6 +180,8 @@ models:
|
||||
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -159,6 +191,8 @@ models:
|
||||
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
@@ -168,6 +202,8 @@ models:
|
||||
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
logo: bedrock
|
||||
@@ -6,6 +6,8 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -16,6 +18,8 @@ models:
|
||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -26,6 +30,8 @@ models:
|
||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -36,6 +42,8 @@ models:
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -46,6 +54,8 @@ models:
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -56,6 +66,8 @@ models:
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -66,6 +78,8 @@ models:
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -76,6 +90,8 @@ models:
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -88,6 +104,8 @@ models:
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -100,6 +118,9 @@ models:
|
||||
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
@@ -112,6 +133,9 @@ models:
|
||||
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
@@ -124,6 +148,8 @@ models:
|
||||
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -135,6 +161,8 @@ models:
|
||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -147,6 +175,8 @@ models:
|
||||
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -159,6 +189,8 @@ models:
|
||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -171,6 +203,8 @@ models:
|
||||
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
@@ -182,6 +216,8 @@ models:
|
||||
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
@@ -193,6 +229,8 @@ models:
|
||||
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -205,6 +243,8 @@ models:
|
||||
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -217,6 +257,8 @@ models:
|
||||
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -229,6 +271,8 @@ models:
|
||||
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -241,6 +285,8 @@ models:
|
||||
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -253,6 +299,8 @@ models:
|
||||
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -265,6 +313,8 @@ models:
|
||||
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -277,6 +327,8 @@ models:
|
||||
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -289,6 +341,10 @@ models:
|
||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -302,6 +358,10 @@ models:
|
||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -315,6 +375,10 @@ models:
|
||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -328,6 +392,10 @@ models:
|
||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -341,6 +409,10 @@ models:
|
||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -354,6 +426,10 @@ models:
|
||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -367,6 +443,8 @@ models:
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -379,6 +457,8 @@ models:
|
||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -391,6 +471,8 @@ models:
|
||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -403,6 +485,8 @@ models:
|
||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -415,6 +499,8 @@ models:
|
||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -427,6 +513,8 @@ models:
|
||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -439,6 +527,8 @@ models:
|
||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -451,6 +541,8 @@ models:
|
||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -463,6 +555,8 @@ models:
|
||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -475,6 +569,8 @@ models:
|
||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -487,6 +583,8 @@ models:
|
||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -498,6 +596,8 @@ models:
|
||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -509,6 +609,8 @@ models:
|
||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -520,6 +622,8 @@ models:
|
||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
@@ -531,6 +635,8 @@ models:
|
||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -544,6 +650,8 @@ models:
|
||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -557,6 +665,8 @@ models:
|
||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -569,6 +679,8 @@ models:
|
||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -582,6 +694,8 @@ models:
|
||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -594,6 +708,8 @@ models:
|
||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -606,6 +722,11 @@ models:
|
||||
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
- audio
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -620,6 +741,10 @@ models:
|
||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -635,6 +760,10 @@ models:
|
||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -650,6 +779,10 @@ models:
|
||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -665,6 +798,10 @@ models:
|
||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -680,6 +817,10 @@ models:
|
||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -695,6 +836,10 @@ models:
|
||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -708,6 +853,10 @@ models:
|
||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- video
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
@@ -721,6 +870,8 @@ models:
|
||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -732,6 +883,8 @@ models:
|
||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -743,6 +896,8 @@ models:
|
||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -754,6 +909,8 @@ models:
|
||||
description: gte-rerank-v2重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
@@ -763,6 +920,8 @@ models:
|
||||
description: gte-rerank重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
@@ -772,6 +931,9 @@ models:
|
||||
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 多模态模型
|
||||
@@ -783,6 +945,8 @@ models:
|
||||
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -793,6 +957,8 @@ models:
|
||||
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -803,6 +969,8 @@ models:
|
||||
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
@@ -813,7 +981,9 @@ models:
|
||||
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
logo: dashscope
|
||||
@@ -6,7 +6,7 @@ from typing import Callable
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
from app.models.models_model import ModelBase, ModelProvider, ModelConfig
|
||||
|
||||
|
||||
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
@@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
for model_data in models:
|
||||
config_sync_fields = {
|
||||
"logo": None,
|
||||
"capability": None,
|
||||
"is_omni": None,
|
||||
"name": None,
|
||||
"provider": None,
|
||||
"type": None,
|
||||
"description": None
|
||||
}
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
existing = db.query(ModelBase).filter(
|
||||
@@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
# 更新现有模型配置
|
||||
for key, value in model_data.items():
|
||||
setattr(existing, key, value)
|
||||
|
||||
# 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey
|
||||
sync_fields = [k for k in config_sync_fields.keys() if k in model_data]
|
||||
if sync_fields:
|
||||
# 批量更新 ModelConfig
|
||||
update_kwargs = {k: model_data[k] for k in sync_fields}
|
||||
db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update(
|
||||
update_kwargs,
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
# 更新 ModelApiKey 的 capability 和 is_omni
|
||||
if 'capability' in model_data or 'is_omni' in model_data:
|
||||
from app.models.models_model import ModelApiKey, model_config_api_key_association
|
||||
api_key_update = {}
|
||||
if 'capability' in model_data:
|
||||
api_key_update['capability'] = model_data['capability']
|
||||
if 'is_omni' in model_data:
|
||||
api_key_update['is_omni'] = model_data['is_omni']
|
||||
|
||||
if api_key_update:
|
||||
# 查找所有关联的 API Key
|
||||
api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join(
|
||||
ModelConfig,
|
||||
ModelConfig.id == model_config_api_key_association.c.model_config_id
|
||||
).filter(ModelConfig.model_id == existing.id).distinct().all()
|
||||
|
||||
if api_key_ids:
|
||||
api_key_ids = [aid[0] for aid in api_key_ids]
|
||||
db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update(
|
||||
api_key_update,
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
db.commit()
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
|
||||
@@ -6,12 +6,19 @@ models:
|
||||
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
is_omni: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- audio
|
||||
- video
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-0125
|
||||
type: llm
|
||||
@@ -19,6 +26,8 @@ models:
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -31,6 +40,8 @@ models:
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -43,6 +54,8 @@ models:
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -55,6 +68,8 @@ models:
|
||||
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: openai
|
||||
@@ -64,6 +79,8 @@ models:
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -76,6 +93,8 @@ models:
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -88,6 +107,8 @@ models:
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -100,6 +121,9 @@ models:
|
||||
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -113,6 +137,8 @@ models:
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -125,6 +151,9 @@ models:
|
||||
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -138,6 +167,8 @@ models:
|
||||
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -148,6 +179,9 @@ models:
|
||||
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
@@ -162,6 +196,9 @@ models:
|
||||
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -176,6 +213,8 @@ models:
|
||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -189,6 +228,8 @@ models:
|
||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -202,6 +243,9 @@ models:
|
||||
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -215,6 +259,9 @@ models:
|
||||
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -228,6 +275,9 @@ models:
|
||||
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -242,6 +292,9 @@ models:
|
||||
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -256,6 +309,9 @@ models:
|
||||
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability:
|
||||
- vision
|
||||
is_omni: false
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
@@ -270,6 +326,8 @@ models:
|
||||
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
@@ -279,6 +337,8 @@ models:
|
||||
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
@@ -288,6 +348,8 @@ models:
|
||||
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
capability: []
|
||||
is_omni: false
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
logo: openai
|
||||
@@ -12,7 +12,7 @@ from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModel
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
@@ -69,9 +69,27 @@ class DifyConverter(BaseConverter):
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
|
||||
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
|
||||
return func
|
||||
|
||||
def config_validate(
|
||||
self,
|
||||
node_id: str,
|
||||
node_name: str,
|
||||
config: type[BaseNodeConfig],
|
||||
value: dict
|
||||
):
|
||||
try:
|
||||
return config.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_variable(expression) -> bool:
|
||||
return bool(re.match(r"\{\{#(.*?)#}}", expression))
|
||||
@@ -80,14 +98,16 @@ class DifyConverter(BaseConverter):
|
||||
if not var_selector:
|
||||
return ""
|
||||
selector = var_selector.split('.')
|
||||
if len(selector) != 2:
|
||||
if len(selector) not in [2, 3] and var_selector != "context":
|
||||
raise Exception(f"invalid variable selector: {var_selector}")
|
||||
if len(selector) == 3:
|
||||
selector = selector[1:]
|
||||
if selector[0] == "conversation":
|
||||
selector[0] = "conv"
|
||||
var_selector = ".".join(selector)
|
||||
mapping = {
|
||||
"sys.query": "sys.message"
|
||||
} | self.node_output_map
|
||||
"sys.query": "sys.message"
|
||||
} | self.node_output_map
|
||||
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
@@ -237,7 +257,7 @@ class DifyConverter(BaseConverter):
|
||||
node_id=node["id"],
|
||||
node_name=node_data["title"],
|
||||
name=var["variable"],
|
||||
detail=f"Unsupport Variable type for start node: {var_type}"
|
||||
detail=f"Unsupported Variable type for start node: {var_type}"
|
||||
)
|
||||
)
|
||||
continue
|
||||
@@ -253,9 +273,11 @@ class DifyConverter(BaseConverter):
|
||||
max_length=var.get("max_length"),
|
||||
)
|
||||
start_vars.append(var_def)
|
||||
return StartNodeConfig(
|
||||
result = StartNodeConfig.model_construct(
|
||||
variables=start_vars
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_question_classifier_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
@@ -270,16 +292,18 @@ class DifyConverter(BaseConverter):
|
||||
for category in node_data["classes"]:
|
||||
self.branch_node_cache[node["id"]].append(category["id"])
|
||||
categories.append(
|
||||
ClassifierConfig(
|
||||
ClassifierConfig.model_construct(
|
||||
class_name=category["name"],
|
||||
)
|
||||
)
|
||||
|
||||
return QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
|
||||
result = QuestionClassifierNodeConfig.model_construct(
|
||||
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
|
||||
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
|
||||
categories=categories,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_llm_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
@@ -308,14 +332,16 @@ class DifyConverter(BaseConverter):
|
||||
messages.append(
|
||||
MessageConfig(
|
||||
role="user",
|
||||
content=self.trans_variable_format(node_data["memory"]["query_prompt_template"])
|
||||
content=self.trans_variable_format(
|
||||
node_data["memory"].get("query_prompt_template", "{{#sys.query#}}")
|
||||
)
|
||||
)
|
||||
)
|
||||
vision = node_data["vision"]["enabled"]
|
||||
vision_input = self._process_list_variable_litearl(
|
||||
node_data["vision"]["configs"]["variable_selector"]
|
||||
) if vision else None
|
||||
return LLMNodeConfig.model_construct(
|
||||
result = LLMNodeConfig.model_construct(
|
||||
model_id=None,
|
||||
context=context,
|
||||
memory=memory,
|
||||
@@ -323,12 +349,16 @@ class DifyConverter(BaseConverter):
|
||||
vision_input=vision_input,
|
||||
messages=messages
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_end_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
return EndNodeConfig(
|
||||
output=self.trans_variable_format(node_data["answer"]),
|
||||
result = EndNodeConfig.model_construct(
|
||||
output=self.trans_variable_format(node_data.get("answer", "")),
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_if_else_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
@@ -359,9 +389,11 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
self.branch_node_cache[node["id"]].append(case_id)
|
||||
return IfElseNodeConfig(
|
||||
result = IfElseNodeConfig.model_construct(
|
||||
cases=cases
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_loop_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
@@ -370,7 +402,7 @@ class DifyConverter(BaseConverter):
|
||||
for condition in node_data["break_conditions"]:
|
||||
right_value = condition["value"]
|
||||
conditions.append(
|
||||
LoopConditionDetail(
|
||||
LoopConditionDetail.model_construct(
|
||||
operator=self.convert_compare_operator(condition["comparison_operator"]),
|
||||
left=self._process_list_variable_litearl(condition["variable_selector"]),
|
||||
right=self.trans_variable_format(
|
||||
@@ -383,7 +415,7 @@ class DifyConverter(BaseConverter):
|
||||
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
|
||||
)
|
||||
)
|
||||
condition_config = ConditionsConfig(
|
||||
condition_config = ConditionsConfig.model_construct(
|
||||
logical_operator=logical_operator,
|
||||
expressions=conditions
|
||||
)
|
||||
@@ -392,9 +424,9 @@ class DifyConverter(BaseConverter):
|
||||
right_input_type = variable["value_type"]
|
||||
right_value_type = self.variable_type_map(variable["var_type"])
|
||||
if right_input_type == ValueInputType.VARIABLE:
|
||||
right_value = self._process_list_variable_litearl(variable["value"])
|
||||
right_value = self._process_list_variable_litearl(variable.get("value", ""))
|
||||
else:
|
||||
right_value = self.convert_variable_type(right_value_type, variable["value"])
|
||||
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
|
||||
loop_variables.append(
|
||||
CycleVariable(
|
||||
name=variable["label"],
|
||||
@@ -403,23 +435,28 @@ class DifyConverter(BaseConverter):
|
||||
input_type=right_input_type
|
||||
)
|
||||
)
|
||||
return LoopNodeConfig(
|
||||
result = LoopNodeConfig.model_construct(
|
||||
condition=condition_config,
|
||||
cycle_vars=loop_variables,
|
||||
max_loop=node_data["loop_count"]
|
||||
max_loop=node_data.get("loop_count", 10)
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_iteration_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
return IterationNodeConfig(
|
||||
result = IterationNodeConfig.model_construct(
|
||||
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
|
||||
parallel=node_data["is_parallel"],
|
||||
parallel_count=node_data["parallel_nums"],
|
||||
output=self._process_list_variable_litearl(node_data["output_selector"]),
|
||||
output_type=self.variable_type_map(node_data["output_type"]),
|
||||
output_type=self.variable_type_map(node_data.get("output_type")),
|
||||
flatten=node_data["flatten_output"],
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_assigner_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
assignments = []
|
||||
@@ -435,16 +472,18 @@ class DifyConverter(BaseConverter):
|
||||
operation=self.convert_assignment_operator(assignment["operation"])
|
||||
)
|
||||
)
|
||||
return AssignerNodeConfig(
|
||||
result = AssignerNodeConfig.model_construct(
|
||||
assignments=assignments
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_code_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
input_variables = []
|
||||
for input_variable in node_data["variables"]:
|
||||
input_variables.append(
|
||||
InputVariable(
|
||||
InputVariable.model_construct(
|
||||
name=input_variable["variable"],
|
||||
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
|
||||
)
|
||||
@@ -453,7 +492,7 @@ class DifyConverter(BaseConverter):
|
||||
output_variables = []
|
||||
for output_variable in node_data["outputs"]:
|
||||
output_variables.append(
|
||||
OutputVariable(
|
||||
OutputVariable.model_construct(
|
||||
name=output_variable,
|
||||
type=node_data["outputs"][output_variable]["type"],
|
||||
)
|
||||
@@ -461,18 +500,20 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
|
||||
|
||||
return CodeNodeConfig(
|
||||
result = CodeNodeConfig.model_construct(
|
||||
input_variables=input_variables,
|
||||
language=node_data["code_language"],
|
||||
output_variables=output_variables,
|
||||
code=code
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_http_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
if node_data["authorization"] != 'no-auth':
|
||||
if node_data["authorization"]["type"] != 'no-auth':
|
||||
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
|
||||
auth_config = HttpAuthConfig(
|
||||
auth_config = HttpAuthConfig.model_construct(
|
||||
auth_type=auth_type,
|
||||
header=node_data["authorization"]["config"].get("header"),
|
||||
api_key=node_data["authorization"]["config"].get("api_key"),
|
||||
@@ -504,7 +545,7 @@ class DifyConverter(BaseConverter):
|
||||
body_content = ""
|
||||
|
||||
headers = {}
|
||||
for header in node_data["headers"].split("\n"):
|
||||
for header in node_data.get("headers", "").split("\n"):
|
||||
if not header:
|
||||
continue
|
||||
|
||||
@@ -522,7 +563,7 @@ class DifyConverter(BaseConverter):
|
||||
))
|
||||
|
||||
params = {}
|
||||
for param in node_data["params"].split("\n"):
|
||||
for param in node_data.get("params", "").split("\n"):
|
||||
if not param:
|
||||
continue
|
||||
|
||||
@@ -547,7 +588,7 @@ class DifyConverter(BaseConverter):
|
||||
default_body = ""
|
||||
default_header = {}
|
||||
default_status_code = 0
|
||||
for var in node_data["default_value"]:
|
||||
for var in node_data.get("default_value") or []:
|
||||
if var["key"] == "body":
|
||||
default_body = var["value"]
|
||||
elif var["key"] == "header":
|
||||
@@ -561,45 +602,50 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
|
||||
self.error_branch_node_cache.append(node['id'])
|
||||
return HttpRequestNodeConfig(
|
||||
result = HttpRequestNodeConfig.model_construct(
|
||||
method=node_data["method"].upper(),
|
||||
url=node_data["url"],
|
||||
auth=auth_config,
|
||||
body=HttpContentTypeConfig(
|
||||
body=HttpContentTypeConfig.model_construct(
|
||||
content_type=self.convert_http_content_type(node_data["body"]["type"]),
|
||||
data=body_content,
|
||||
),
|
||||
headers=headers,
|
||||
params=params,
|
||||
verify_ssl=node_data["ssl_verify"],
|
||||
timeouts=HttpTimeOutConfig(
|
||||
timeouts=HttpTimeOutConfig.model_construct(
|
||||
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
|
||||
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
|
||||
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
|
||||
),
|
||||
retry=HttpRetryConfig(
|
||||
retry=HttpRetryConfig.model_construct(
|
||||
enable=node_data["retry_config"]["retry_enabled"],
|
||||
max_attempts=node_data["retry_config"]["max_retries"],
|
||||
retry_interval=node_data["retry_config"]["retry_interval"],
|
||||
),
|
||||
error_handle=HttpErrorHandleConfig(
|
||||
error_handle=HttpErrorHandleConfig.model_construct(
|
||||
method=error_handle_type,
|
||||
default=default_value,
|
||||
)
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_jinja_render_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
mapping = []
|
||||
for variable in node_data["variables"]:
|
||||
mapping.append(VariablesMappingConfig(
|
||||
mapping.append(VariablesMappingConfig.model_construct(
|
||||
name=variable["variable"],
|
||||
value=self._process_list_variable_litearl(variable["value_selector"])
|
||||
))
|
||||
return JinjaRenderNodeConfig(
|
||||
result = JinjaRenderNodeConfig.model_construct(
|
||||
template=node_data["template"],
|
||||
mapping=mapping,
|
||||
).model_dump()
|
||||
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_knowledge_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
@@ -609,10 +655,13 @@ class DifyConverter(BaseConverter):
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the Knowledge Retrieval node.",
|
||||
))
|
||||
return KnowledgeRetrievalNodeConfig.model_construct(
|
||||
result = KnowledgeRetrievalNodeConfig.model_construct(
|
||||
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(
|
||||
@@ -623,46 +672,53 @@ class DifyConverter(BaseConverter):
|
||||
)
|
||||
)
|
||||
params = []
|
||||
for param in node_data["parameters"]:
|
||||
for param in node_data.get("parameters", []):
|
||||
params.append(
|
||||
ParamsConfig(
|
||||
ParamsConfig.model_construct(
|
||||
name=param["name"],
|
||||
desc=param["description"],
|
||||
required=param["required"],
|
||||
type=param["type"],
|
||||
)
|
||||
)
|
||||
return ParameterExtractorNodeConfig.model_construct(
|
||||
result = ParameterExtractorNodeConfig.model_construct(
|
||||
text=self._process_list_variable_litearl(node_data["query"]),
|
||||
params=params,
|
||||
prompt=node_data["instruction"]
|
||||
prompt=node_data.get("instruction")
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
|
||||
return result
|
||||
|
||||
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
group_enable = node_data["advanced_settings"]["group_enabled"]
|
||||
advanced_settings = node_data.get("advanced_settings", {})
|
||||
group_variables = {}
|
||||
group_type = {}
|
||||
if not group_enable:
|
||||
if not advanced_settings or not advanced_settings["group_enabled"]:
|
||||
group_variables["output"] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in node_data["variables"]
|
||||
]
|
||||
group_type["output"] = node_data["output_type"]
|
||||
else:
|
||||
for group in node_data["advanced_settings"]["groups"]:
|
||||
for group in advanced_settings["groups"]:
|
||||
group_variables[group["group_name"]] = [
|
||||
self._process_list_variable_litearl(variable)
|
||||
for variable in group["variables"]
|
||||
]
|
||||
group_type[group["group_name"]] = group["output_type"]
|
||||
|
||||
return VariableAggregatorNodeConfig(
|
||||
group=group_enable,
|
||||
result = VariableAggregatorNodeConfig.model_construct(
|
||||
group=advanced_settings.get("group_enabled", False),
|
||||
group_variables=group_variables,
|
||||
group_type=group_type,
|
||||
).model_dump()
|
||||
|
||||
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
|
||||
|
||||
return result
|
||||
|
||||
def convert_tool_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
@@ -671,4 +727,4 @@ class DifyConverter(BaseConverter):
|
||||
type=ExceptionType.CONFIG,
|
||||
detail=f"Please reconfigure the tool node.",
|
||||
))
|
||||
return {}
|
||||
return {}
|
||||
|
||||
@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type)
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
@@ -80,7 +80,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'dependencies', 'kind', 'version', 'workflow'})
|
||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
@@ -179,8 +179,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
node_type = node_data["type"]
|
||||
try:
|
||||
converter = self.get_node_convert(node_type)
|
||||
if converter is None:
|
||||
raise Exception(f"node type not supported - {node_type}")
|
||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"node type {node_type} is unsupported",
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
|
||||
@@ -127,7 +127,7 @@ class EventStreamHandler:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
"content": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
"content": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
|
||||
instance:
|
||||
The concrete variable object. The actual Python type is
|
||||
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||
NumberVariable, ArrayObject[StringVariable]).
|
||||
NumberVariable, ArrayVariable[StringVariable]).
|
||||
mut:
|
||||
Whether the variable is mutable.
|
||||
"""
|
||||
@@ -152,6 +152,36 @@ class VariablePool:
|
||||
return None
|
||||
return var_instance
|
||||
|
||||
def get_instance(
|
||||
self,
|
||||
selector: str,
|
||||
default: Any = None,
|
||||
strict: bool = True
|
||||
):
|
||||
"""Retrieve a variable instance from the variable pool.
|
||||
|
||||
Args:
|
||||
selector:
|
||||
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
|
||||
default:
|
||||
The value to return if the variable does not exist.
|
||||
strict:
|
||||
If True, raises KeyError when the variable does not exist.
|
||||
|
||||
Returns:
|
||||
The variable instance object if it exists; otherwise returns `default`.
|
||||
|
||||
Raises:
|
||||
KeyError: If strict is True and the variable does not exist.
|
||||
"""
|
||||
variable_struct = self._get_variable_struct(selector)
|
||||
if variable_struct is None:
|
||||
if strict:
|
||||
raise KeyError(f"{selector} not exist")
|
||||
return default
|
||||
|
||||
return variable_struct.instance
|
||||
|
||||
def get_value(
|
||||
self,
|
||||
selector: str,
|
||||
@@ -273,38 +303,52 @@ class VariablePool:
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def get_all_system_vars(self) -> dict[str, Any]:
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
Returns:
|
||||
系统变量字典
|
||||
"""
|
||||
sys_namespace = self.variables.get("sys", {})
|
||||
if literal:
|
||||
return {k: v.instance.to_literal() for k, v in sys_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in sys_namespace.items()}
|
||||
|
||||
def get_all_conversation_vars(self) -> dict[str, Any]:
|
||||
def get_all_conversation_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有会话变量
|
||||
|
||||
Returns:
|
||||
会话变量字典
|
||||
"""
|
||||
conv_namespace = self.variables.get("conv", {})
|
||||
if literal:
|
||||
return {k: v.instance.to_literal() for k, v in conv_namespace.items()}
|
||||
return {k: v.instance.get_value() for k, v in conv_namespace.items()}
|
||||
|
||||
def get_all_node_outputs(self) -> dict[str, Any]:
|
||||
def get_all_node_outputs(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有节点输出(运行时变量)
|
||||
|
||||
Returns:
|
||||
节点输出字典,键为节点 ID
|
||||
"""
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
if literal:
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.to_literal()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
else:
|
||||
runtime_vars = {
|
||||
namespace: {
|
||||
k: v.instance.get_value()
|
||||
for k, v in vars_dict.items()
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
for namespace, vars_dict in self.variables.items()
|
||||
if namespace not in ("sys", "conv")
|
||||
}
|
||||
return runtime_vars
|
||||
|
||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||
|
||||
@@ -132,24 +132,24 @@ class WorkflowExecutor:
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
@@ -231,23 +231,23 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
|
||||
# Initialize the variable pool and system variables
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
full_content = ''
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
@@ -272,7 +272,7 @@ class WorkflowExecutor:
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
full_content += msg_event["data"]["content"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
@@ -295,12 +295,12 @@ class WorkflowExecutor:
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db
|
||||
from app.models import AppRelease
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +39,7 @@ class AgentNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]:
|
||||
def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AgentRunService, AppRelease, str]:
|
||||
"""准备 Agent(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,7 @@ class AgentNode(BaseNode):
|
||||
if not release:
|
||||
raise ValueError(f"Agent 不存在: {agent_id}")
|
||||
|
||||
draft_service = DraftRunService(db)
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
return draft_service, release, message
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import Any, AsyncGenerator
|
||||
@@ -10,8 +11,10 @@ from app.core.config import settings
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.enums import BRANCH_NODES
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.services.multimodal_service import PROVIDER_STRATEGIES
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.schemas import FileInput
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -548,9 +551,9 @@ class BaseNode(ABC):
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
conv_vars=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars(),
|
||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
@@ -614,16 +617,32 @@ class BaseNode(ABC):
|
||||
return variable_pool.has(selector)
|
||||
|
||||
@staticmethod
|
||||
async def process_message(provider, content, enable_file=False) -> dict | str | None:
|
||||
async def process_message(provider: str, content: str | FileObject, enable_file=False) -> dict | str | None:
|
||||
if isinstance(content, str):
|
||||
if enable_file:
|
||||
return {"text": content}
|
||||
return content
|
||||
elif isinstance(content, dict):
|
||||
trans_tool = PROVIDER_STRATEGIES[provider]()
|
||||
result = await trans_tool.format_image(content["url"])
|
||||
return result
|
||||
raise TypeError('Unexpect input value type')
|
||||
|
||||
elif isinstance(content, FileObject):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
)
|
||||
|
||||
if message:
|
||||
content.content_cache[provider] = message[0]
|
||||
return message[0]
|
||||
return None
|
||||
raise TypeError(f'Unexpect input value type - {type(content)}')
|
||||
|
||||
@staticmethod
|
||||
def process_model_output(content) -> str:
|
||||
|
||||
@@ -91,8 +91,8 @@ class IterationRuntime:
|
||||
return loopstate
|
||||
|
||||
def merge_conv_vars(self):
|
||||
self.variable_pool.get_all_conversation_vars().update(
|
||||
self.child_variable_pool.get_all_conversation_vars()
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
|
||||
async def run_task(self, item, idx):
|
||||
|
||||
@@ -156,7 +156,7 @@ class LoopRuntime:
|
||||
|
||||
def merge_conv_vars(self, loopstate):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
|
||||
if config.flatten:
|
||||
outputs['output'] = config.output_type
|
||||
else:
|
||||
outputs['output'] = VariableType.ARRAY_STRING
|
||||
outputs['output'] = VariableType.NESTED_ARRAY
|
||||
else:
|
||||
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
||||
return outputs
|
||||
|
||||
@@ -24,6 +24,8 @@ class NodeType(StrEnum):
|
||||
MEMORY_READ = "memory-read"
|
||||
MEMORY_WRITE = "memory-write"
|
||||
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
|
||||
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||
return params
|
||||
|
||||
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""
|
||||
Build HTTP request body arguments for httpx request methods.
|
||||
|
||||
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
content["files"] = {}
|
||||
for item in self.typed_config.body.data:
|
||||
if item.type == "text":
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
|
||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
||||
variable_pool)
|
||||
elif item.type == "file":
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
content["files"][self._render_template(item.key, variable_pool)] = (
|
||||
uuid.uuid4().hex,
|
||||
await variable_pool.get_instance(item.value).get_content()
|
||||
)
|
||||
content["data"] = data
|
||||
case HttpContentType.BINARY:
|
||||
# TODO: File support (Feature)
|
||||
pass
|
||||
content["files"] = []
|
||||
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||
if isinstance(file_instence, ArrayVariable):
|
||||
for v in file_instence.value:
|
||||
if isinstance(v, FileVariable):
|
||||
content["files"].append(
|
||||
(
|
||||
"files", (uuid.uuid4().hex, await v.get_content())
|
||||
)
|
||||
)
|
||||
elif isinstance(file_instence, FileVariable):
|
||||
content["files"].append(
|
||||
(
|
||||
"file", (uuid.uuid4().hex, await file_instence.get_content())
|
||||
)
|
||||
)
|
||||
|
||||
case HttpContentType.WWW_FORM:
|
||||
content["data"] = json.loads(self._render_template(
|
||||
json.dumps(self.typed_config.body.data), variable_pool
|
||||
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
|
||||
request_func = self._get_client_method(client)
|
||||
resp = await request_func(
|
||||
url=self._render_template(self.typed_config.url, variable_pool),
|
||||
**self._build_content(variable_pool)
|
||||
**(await self._build_content(variable_pool))
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
|
||||
@@ -172,9 +172,9 @@ class LLMNode(BaseNode):
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_value(self.typed_config.vision_input)
|
||||
for file in files:
|
||||
content = await self.process_message(provider, file, self.typed_config.vision)
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, file.value, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.append(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
|
||||
@@ -123,10 +123,10 @@ class NodeFactory:
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
raise ValueError(f"不支持的节点类型: {node_type}")
|
||||
raise ValueError(f"Unsupported node type: {node_type}")
|
||||
|
||||
# 创建节点实例
|
||||
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
|
||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||
return node_class(node_config, workflow_config)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
@@ -45,7 +45,7 @@ class VariableType(StrEnum):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')):
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var, dict):
|
||||
return cls.OBJECT
|
||||
@@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any:
|
||||
class FileObject(BaseModel):
|
||||
type: FileType
|
||||
url: str
|
||||
__file: bool
|
||||
transfer_method: str
|
||||
origin_file_type: str
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any, TypeVar, Type, Generic
|
||||
|
||||
import httpx
|
||||
from deprecated import deprecated
|
||||
|
||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
|
||||
from app.core.config import settings
|
||||
|
||||
T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
@@ -61,13 +63,16 @@ class FileVariable(BaseVariable):
|
||||
def valid_value(self, value) -> FileObject:
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value.get("__file"):
|
||||
if not value.get("is_file"):
|
||||
raise TypeError(f"Value must be a FileObject - {type(value)}:{value}")
|
||||
return FileObject(
|
||||
**{
|
||||
"type": str(value.get('type')),
|
||||
"transfer_method": value.get("transfer_method"),
|
||||
"url": value.get('url'),
|
||||
"__file": True
|
||||
"file_id": value.get("file_id"),
|
||||
"origin_file_type": value.get("origin_file_type"),
|
||||
"is_file": True
|
||||
}
|
||||
)
|
||||
if isinstance(value, FileObject):
|
||||
@@ -80,8 +85,23 @@ class FileVariable(BaseVariable):
|
||||
def get_value(self) -> Any:
|
||||
return self.value.model_dump()
|
||||
|
||||
async def get_content(self):
|
||||
total_bytes = 0
|
||||
chunks = []
|
||||
|
||||
class ArrayObject(BaseVariable, Generic[T]):
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with client.stream("GET", self.value.url) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes(8192):
|
||||
total_bytes += len(chunk)
|
||||
if total_bytes > settings.MAX_FILE_SIZE:
|
||||
raise ValueError(f"File too large: {total_bytes} bytes")
|
||||
chunks.append(chunk)
|
||||
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
class ArrayVariable(BaseVariable, Generic[T]):
|
||||
type = 'array'
|
||||
|
||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||
@@ -108,7 +128,7 @@ class ArrayObject(BaseVariable, Generic[T]):
|
||||
return [v.get_value() for v in self.value]
|
||||
|
||||
|
||||
class NestedArrayObject(BaseVariable):
|
||||
class NestedArrayVariable(BaseVariable):
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
@@ -116,23 +136,23 @@ class NestedArrayObject(BaseVariable):
|
||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||
final_value = []
|
||||
for v in value:
|
||||
if not isinstance(v, ArrayObject):
|
||||
if not isinstance(v, list):
|
||||
raise TypeError("All elements must be of type list")
|
||||
final_value.append(v)
|
||||
final_value.append(make_array(AnyVariable, v))
|
||||
return final_value
|
||||
|
||||
def to_literal(self) -> str:
|
||||
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
|
||||
return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
|
||||
|
||||
def get_value(self) -> Any:
|
||||
return [[item.get_value() for item in row] for row in self.value]
|
||||
return [[item for item in row.get_value()] for row in self.value]
|
||||
|
||||
|
||||
@deprecated(
|
||||
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyObject(BaseVariable):
|
||||
class AnyVariable(BaseVariable):
|
||||
type = 'any'
|
||||
|
||||
def valid_value(self, value: Any) -> Any:
|
||||
@@ -142,10 +162,10 @@ class AnyObject(BaseVariable):
|
||||
return str(self.value)
|
||||
|
||||
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
||||
"""简化 ArrayObject 创建,不需要重复写类型"""
|
||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
|
||||
"""简化 ArrayVariable 创建,不需要重复写类型"""
|
||||
|
||||
return ArrayObject(child_type, value)
|
||||
return ArrayVariable(child_type, value)
|
||||
|
||||
|
||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
@@ -168,7 +188,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||
return make_array(DictVariable, value)
|
||||
case VariableType.ARRAY_FILE:
|
||||
return make_array(FileVariable, value)
|
||||
case VariableType.NESTED_ARRAY:
|
||||
return NestedArrayVariable(value)
|
||||
case VariableType.ANY:
|
||||
return AnyObject(value)
|
||||
return AnyVariable(value)
|
||||
case _:
|
||||
raise TypeError(f"Invalid type - {var_type}")
|
||||
|
||||
@@ -35,6 +35,7 @@ from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .ontology_scene import OntologyScene
|
||||
from .ontology_class import OntologyClass
|
||||
from .implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -90,5 +91,6 @@ __all__ = [
|
||||
"MemoryPerceptualModel",
|
||||
"ModelBase",
|
||||
"LoadBalanceStrategy",
|
||||
"Skill"
|
||||
"Skill",
|
||||
"ImplicitEmotionsStorage"
|
||||
]
|
||||
|
||||
45
api/app/models/implicit_emotions_storage_model.py
Normal file
45
api/app/models/implicit_emotions_storage_model.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Implicit Emotions Storage Model
|
||||
|
||||
数据库模型:存储用户的隐性记忆画像和情绪建议数据
|
||||
替代原有的Redis缓存方式
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ImplicitEmotionsStorage(Base):
|
||||
"""隐性记忆和情绪存储表"""
|
||||
|
||||
__tablename__ = "implicit_emotions_storage"
|
||||
|
||||
# 主键
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, comment="主键ID")
|
||||
|
||||
# 用户标识(unique=True会自动创建唯一索引)
|
||||
end_user_id = Column(String(255), nullable=False, unique=True, comment="终端用户ID")
|
||||
|
||||
# 隐性记忆画像数据(JSON格式)
|
||||
implicit_profile = Column(JSONB, nullable=True, comment="隐性记忆用户画像数据")
|
||||
|
||||
# 情绪建议数据(JSON格式)
|
||||
emotion_suggestions = Column(JSONB, nullable=True, comment="情绪个性化建议数据")
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.utcnow, comment="创建时间")
|
||||
updated_at = Column(DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, comment="更新时间")
|
||||
|
||||
# 数据生成时间(用于业务逻辑)
|
||||
implicit_generated_at = Column(DateTime, nullable=True, comment="隐性记忆画像生成时间")
|
||||
emotion_generated_at = Column(DateTime, nullable=True, comment="情绪建议生成时间")
|
||||
|
||||
# 索引(只为updated_at创建索引,end_user_id的unique约束已自动创建索引)
|
||||
__table_args__ = (
|
||||
Index('idx_updated_at', 'updated_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ImplicitEmotionsStorage(id={self.id}, end_user_id={self.end_user_id})>"
|
||||
@@ -2,7 +2,7 @@ import datetime
|
||||
import uuid
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
@@ -78,6 +78,9 @@ class ModelConfig(BaseModel):
|
||||
description = Column(String, comment="模型描述")
|
||||
|
||||
# 模型配置参数
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
config = Column(JSON, comment="模型配置参数")
|
||||
# - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。
|
||||
# - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。
|
||||
@@ -118,6 +121,11 @@ class ModelApiKey(BaseModel):
|
||||
api_key = Column(String, nullable=False, comment="API密钥")
|
||||
api_base = Column(String, comment="API基础URL")
|
||||
|
||||
# 模型能力参数
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
|
||||
# 配置参数
|
||||
config = Column(JSON, comment="API Key特定配置")
|
||||
|
||||
@@ -155,6 +163,9 @@ class ModelBase(Base):
|
||||
tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])")
|
||||
add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now())
|
||||
capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)")
|
||||
|
||||
# 关联关系
|
||||
configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan")
|
||||
|
||||
169
api/app/repositories/implicit_emotions_storage_repository.py
Normal file
169
api/app/repositories/implicit_emotions_storage_repository.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Implicit Emotions Storage Repository
|
||||
|
||||
数据访问层:处理隐性记忆和情绪数据的数据库操作
|
||||
事务由调用方控制,仓储层只使用 flush/refresh
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime, date, timezone, timedelta
|
||||
from typing import Optional, Generator
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select, not_, exists
|
||||
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from app.models.end_user_model import EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitEmotionsStorageRepository:
|
||||
"""隐性记忆和情绪存储仓储类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_end_user_id(self, end_user_id: str) -> Optional[ImplicitEmotionsStorage]:
|
||||
"""根据终端用户ID获取存储记录"""
|
||||
try:
|
||||
stmt = select(ImplicitEmotionsStorage).where(
|
||||
ImplicitEmotionsStorage.end_user_id == end_user_id
|
||||
)
|
||||
return self.db.execute(stmt).scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户存储记录失败: end_user_id={end_user_id}, error={e}")
|
||||
return None
|
||||
|
||||
def create(self, end_user_id: str) -> ImplicitEmotionsStorage:
|
||||
"""创建新的存储记录(事务由调用方提交)"""
|
||||
storage = ImplicitEmotionsStorage(
|
||||
end_user_id=end_user_id,
|
||||
created_at=datetime.utcnow(),
|
||||
updated_at=datetime.utcnow()
|
||||
)
|
||||
self.db.add(storage)
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"创建用户存储记录成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_implicit_profile(
|
||||
self,
|
||||
end_user_id: str,
|
||||
profile_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新隐性记忆画像数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.implicit_profile = profile_data
|
||||
storage.implicit_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新隐性记忆画像成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def update_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
suggestions_data: dict
|
||||
) -> ImplicitEmotionsStorage:
|
||||
"""更新情绪建议数据(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage is None:
|
||||
storage = self.create(end_user_id)
|
||||
|
||||
storage.emotion_suggestions = suggestions_data
|
||||
storage.emotion_generated_at = datetime.utcnow()
|
||||
storage.updated_at = datetime.utcnow()
|
||||
|
||||
self.db.flush()
|
||||
self.db.refresh(storage)
|
||||
logger.info(f"更新情绪建议成功: end_user_id={end_user_id}")
|
||||
return storage
|
||||
|
||||
def get_all_user_ids(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取所有已存储数据的用户ID(避免大数据量内存溢出)
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(ImplicitEmotionsStorage.end_user_id)
|
||||
.order_by(ImplicitEmotionsStorage.end_user_id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from batch
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def get_new_user_ids_today(self, batch_size: int = 100) -> Generator[str, None, None]:
|
||||
"""分批次获取当天新增的、尚未初始化隐性记忆和情绪建议数据的用户ID
|
||||
|
||||
查询逻辑:end_users 表中 created_at 为今天,且在 implicit_emotions_storage 中没有对应记录。
|
||||
没有对应记录意味着隐性记忆画像和情绪建议均未初始化,需要对这批用户执行首次初始化。
|
||||
end_users.id(UUID)转为字符串后与 implicit_emotions_storage.end_user_id(String)对比。
|
||||
|
||||
Args:
|
||||
batch_size: 每批次加载的数量,默认100
|
||||
|
||||
Yields:
|
||||
用户ID字符串
|
||||
"""
|
||||
from sqlalchemy import cast, String as SAString
|
||||
CST = timezone(timedelta(hours=8))
|
||||
now_cst = datetime.now(CST)
|
||||
today_start = now_cst.replace(hour=0, minute=0, second=0, microsecond=0).astimezone(timezone.utc).replace(tzinfo=None)
|
||||
tomorrow_start = today_start + timedelta(days=1)
|
||||
offset = 0
|
||||
while True:
|
||||
try:
|
||||
stmt = (
|
||||
select(EndUser.id)
|
||||
.where(
|
||||
EndUser.created_at >= today_start,
|
||||
EndUser.created_at < tomorrow_start,
|
||||
not_(
|
||||
exists(
|
||||
select(ImplicitEmotionsStorage.end_user_id).where(
|
||||
ImplicitEmotionsStorage.end_user_id == cast(EndUser.id, SAString)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
.order_by(EndUser.id)
|
||||
.limit(batch_size)
|
||||
.offset(offset)
|
||||
)
|
||||
batch = self.db.execute(stmt).scalars().all()
|
||||
if not batch:
|
||||
break
|
||||
yield from (str(uid) for uid in batch)
|
||||
offset += batch_size
|
||||
except Exception as e:
|
||||
logger.error(f"分批获取当天新增用户ID失败: offset={offset}, error={e}")
|
||||
break
|
||||
|
||||
def delete_by_end_user_id(self, end_user_id: str) -> bool:
|
||||
"""删除用户的存储记录(事务由调用方提交)"""
|
||||
storage = self.get_by_end_user_id(end_user_id)
|
||||
if storage:
|
||||
self.db.delete(storage)
|
||||
self.db.flush()
|
||||
logger.info(f"删除用户存储记录成功: end_user_id={end_user_id}")
|
||||
return True
|
||||
return False
|
||||
@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和
|
||||
"""
|
||||
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
result = db.query(func.sum(Knowledge.chunk_num)).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id == "Memory"
|
||||
).scalar()
|
||||
|
||||
total = result if result is not None else 0
|
||||
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
|
||||
return total
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||
"""
|
||||
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
count = db.query(Knowledge).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id != "Memory"
|
||||
).count()
|
||||
|
||||
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -155,8 +155,7 @@ class ApiKey(BaseModel):
|
||||
return datetime.datetime.now() > self.expires_at
|
||||
|
||||
@field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel):
|
||||
avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)")
|
||||
|
||||
@field_serializer('last_used_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel):
|
||||
created_at: datetime.datetime
|
||||
|
||||
@field_serializer('created_at')
|
||||
@classmethod
|
||||
def serialize_datetime(cls, v: datetime.datetime) -> int:
|
||||
def serialize_datetime(self, v: datetime.datetime) -> int:
|
||||
"""将datetime转换为时间戳"""
|
||||
return datetime_to_timestamp(v)
|
||||
|
||||
@@ -21,8 +21,14 @@ class FileType(StrEnum):
|
||||
def trans(cls, value: str) -> 'FileType':
|
||||
if value.startswith("image"):
|
||||
return cls.IMAGE
|
||||
# TODO: other file type support
|
||||
raise RuntimeError("Unsupport file type")
|
||||
elif value.startswith("document"):
|
||||
return cls.DOCUMENT
|
||||
elif value.startswith("audio"):
|
||||
return cls.AUDIO
|
||||
elif value.startswith("video"):
|
||||
return cls.VIDEO
|
||||
else:
|
||||
raise RuntimeError("Unsupport file type")
|
||||
|
||||
|
||||
class TransferMethod(str, Enum):
|
||||
@@ -37,6 +43,12 @@ class FileInput(BaseModel):
|
||||
transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url")
|
||||
upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)")
|
||||
url: Optional[str] = Field(None, description="远程URL(remote_url时必填)")
|
||||
file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)")
|
||||
|
||||
def __init__(self, **data):
|
||||
if "type" in data:
|
||||
data['file_type'] = data['type']
|
||||
super().__init__(**data)
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
@@ -433,7 +445,7 @@ class AppChatRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunRequest(BaseModel):
|
||||
|
||||
@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
|
||||
class ChunkRetrieve(BaseModel):
|
||||
query: str
|
||||
kb_ids: list[uuid.UUID]
|
||||
file_names_filter: list[str] | None = Field(None)
|
||||
similarity_threshold: float | None = Field(None)
|
||||
vector_similarity_weight: float | None = Field(None)
|
||||
top_k: int | None = Field(None)
|
||||
|
||||
@@ -21,6 +21,8 @@ class ModelConfigBase(BaseModel):
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
is_public: bool = Field(False, description="是否公开")
|
||||
load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||
|
||||
|
||||
class ApiKeyCreateNested(BaseModel):
|
||||
@@ -30,6 +32,8 @@ class ApiKeyCreateNested(BaseModel):
|
||||
provider: Optional[str] = Field(None, description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
|
||||
@@ -63,6 +67,8 @@ class ModelConfigUpdate(BaseModel):
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
is_public: Optional[bool] = Field(None, description="是否公开")
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
|
||||
|
||||
class ModelConfig(ModelConfigBase):
|
||||
@@ -95,6 +101,8 @@ class ModelApiKeyCreateByProvider(BaseModel):
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
description: Optional[str] = Field(None, description="备注")
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
@@ -108,6 +116,8 @@ class ModelApiKeyBase(BaseModel):
|
||||
provider: ModelProvider = Field(..., description="API Key提供商")
|
||||
api_key: str = Field(..., description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表")
|
||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||
config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置")
|
||||
is_active: bool = Field(True, description="是否激活")
|
||||
priority: str = Field("1", description="优先级", max_length=10)
|
||||
@@ -124,6 +134,8 @@ class ModelApiKeyUpdate(BaseModel):
|
||||
provider: Optional[ModelProvider] = Field(None, description="API Key提供商")
|
||||
api_key: Optional[str] = Field(None, description="API密钥", max_length=500)
|
||||
api_base: Optional[str] = Field(None, description="API基础URL", max_length=500)
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置")
|
||||
is_active: Optional[bool] = Field(None, description="是否激活")
|
||||
priority: Optional[str] = Field(None, description="优先级", max_length=10)
|
||||
@@ -270,6 +282,8 @@ class ModelBaseCreate(BaseModel):
|
||||
description: Optional[str] = Field(None, description="模型描述")
|
||||
is_official: bool = Field(True, description="是否供应商官方模型")
|
||||
tags: List[str] = Field(default_factory=list, description="模型标签")
|
||||
capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video'])")
|
||||
is_omni: bool = Field(False, description="是否为Omni模型")
|
||||
|
||||
|
||||
class ModelBaseUpdate(BaseModel):
|
||||
@@ -282,6 +296,8 @@ class ModelBaseUpdate(BaseModel):
|
||||
is_deprecated: Optional[bool] = Field(None, description="是否弃用")
|
||||
is_official: Optional[bool] = Field(None, description="是否供应商官方模型")
|
||||
tags: Optional[List[str]] = Field(None, description="模型标签")
|
||||
capability: Optional[List[str]] = Field(None, description="模型能力列表")
|
||||
is_omni: Optional[bool] = Field(None, description="是否为Omni模型")
|
||||
|
||||
|
||||
class ModelBase(BaseModel):
|
||||
@@ -298,6 +314,8 @@ class ModelBase(BaseModel):
|
||||
is_official: bool
|
||||
tags: List[str]
|
||||
add_count: int
|
||||
capability: List[str] = []
|
||||
is_omni: bool = False
|
||||
|
||||
|
||||
class ModelBaseQuery(BaseModel):
|
||||
|
||||
@@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel):
|
||||
class MultiAgentConfigCreate(BaseModel):
|
||||
"""创建多 Agent 配置"""
|
||||
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
|
||||
orchestration_mode: str = Field(
|
||||
default="collaboration",
|
||||
pattern="^(collaboration|supervisor)$",
|
||||
description="协作模式:collaboration(协作)| supervisor(监督)"
|
||||
)
|
||||
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
|
||||
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
|
||||
routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则")
|
||||
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||
aggregation_strategy: str = Field(
|
||||
default="merge",
|
||||
@@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel):
|
||||
class MultiAgentConfigUpdate(BaseModel):
|
||||
"""更新多 Agent 配置"""
|
||||
master_agent_id: Optional[uuid.UUID] = None
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称")
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
|
||||
model_parameters: Optional[ModelParameters] = Field(
|
||||
None,
|
||||
|
||||
@@ -263,8 +263,8 @@ def create_agent_invocation_tool(
|
||||
|
||||
try:
|
||||
# 9. 调用 Agent
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
draft_service = DraftRunService(db)
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
draft_service = AgentRunService(db)
|
||||
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_config,
|
||||
|
||||
@@ -10,25 +10,24 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.services.tool_service import ToolService
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig
|
||||
from app.models import WorkflowConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \
|
||||
AgentRunService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -39,6 +38,8 @@ class AppChatService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.agent_service = AgentRunService(db)
|
||||
self.workflow_service = WorkflowService(db)
|
||||
|
||||
async def agnet_chat(
|
||||
self,
|
||||
@@ -55,12 +56,10 @@ class AppChatService:
|
||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
@@ -79,74 +78,20 @@ class AppChatService:
|
||||
tools = []
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
|
||||
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||
memory_flag = False
|
||||
if memory == True:
|
||||
memory_config = config.memory
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
if memory:
|
||||
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||
config.memory, user_id, storage_type, user_rag_memory_id
|
||||
)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
@@ -157,6 +102,7 @@ class AppChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -180,7 +126,7 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
@@ -245,10 +191,9 @@ class AppChatService:
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||
# 获取模型配置ID
|
||||
model_config_id = config.default_model_config_id
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
@@ -266,73 +211,22 @@ class AppChatService:
|
||||
tools = []
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(config, 'skills') and config.skills:
|
||||
skills = config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
|
||||
|
||||
skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
memory_config = config.memory
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||
config.memory, user_id, storage_type, user_rag_memory_id
|
||||
)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
@@ -343,6 +237,7 @@ class AppChatService:
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -366,13 +261,10 @@ class AppChatService:
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
multimodal_service = MultimodalService(self.db)
|
||||
multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
# 流式调用 Agent(支持多模态)
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
@@ -416,7 +308,7 @@ class AppChatService:
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
logger.info(
|
||||
@@ -435,7 +327,7 @@ class AppChatService:
|
||||
except Exception as e:
|
||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
@@ -489,10 +381,10 @@ class AppChatService:
|
||||
"mode": result.get("mode"),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"usage": result.get("usage", {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
@@ -522,8 +414,6 @@ class AppChatService:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
@@ -629,7 +519,6 @@ class AppChatService:
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
@@ -637,7 +526,7 @@ class AppChatService:
|
||||
stream=True,
|
||||
user_id=user_id
|
||||
)
|
||||
return await workflow_service.run(
|
||||
return await self.workflow_service.run(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
@@ -664,7 +553,6 @@ class AppChatService:
|
||||
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""聊天(流式)"""
|
||||
workflow_service = WorkflowService(self.db)
|
||||
payload = DraftRunRequest(
|
||||
message=message,
|
||||
variables=variables,
|
||||
@@ -673,7 +561,7 @@ class AppChatService:
|
||||
user_id=user_id,
|
||||
files=files
|
||||
)
|
||||
async for event in workflow_service.run_stream(
|
||||
async for event in self.workflow_service.run_stream(
|
||||
app_id=app_id,
|
||||
payload=payload,
|
||||
config=config,
|
||||
|
||||
@@ -232,7 +232,7 @@ class AppService:
|
||||
# 检查主 Agent 的模型配置
|
||||
multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id
|
||||
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id)
|
||||
if not model_api_key:
|
||||
raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id))
|
||||
|
||||
@@ -1791,372 +1791,6 @@ class AppService:
|
||||
|
||||
return shares
|
||||
|
||||
# ==================== 试运行功能 ====================
|
||||
|
||||
async def draft_run(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""试运行 Agent(使用当前草稿配置)
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
message: 用户消息
|
||||
conversation_id: 会话ID(用于多轮对话)
|
||||
user_id: 用户ID(用于会话管理)
|
||||
variables: 自定义变量参数值
|
||||
workspace_id: 工作空间ID(用于权限验证)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 AI 回复和元数据的字典
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用类型不支持或配置缺失时
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 4. 调用试运行服务
|
||||
logger.debug(
|
||||
"准备调用试运行服务",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model": model_config.name,
|
||||
"has_conversation_id": bool(conversation_id),
|
||||
"has_variables": bool(variables)
|
||||
}
|
||||
)
|
||||
|
||||
draft_service = DraftRunService(self.db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"试运行服务返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
|
||||
"has_message": "message" in result if isinstance(result, dict) else False,
|
||||
"has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"试运行完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"elapsed_time": result.get("elapsed_time"),
|
||||
"model": model_config.name
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def draft_run_stream(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""试运行 Agent(流式返回)
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
message: 用户消息
|
||||
conversation_id: 会话ID(用于多轮对话)
|
||||
user_id: 用户ID(用于会话管理)
|
||||
variables: 自定义变量参数值
|
||||
workspace_id: 工作空间ID(用于权限验证)
|
||||
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当应用不存在时
|
||||
BusinessException: 当应用类型不支持或配置缺失时
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
|
||||
if not model_config:
|
||||
raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 4. 调用流式试运行服务
|
||||
draft_service = DraftRunService(self.db)
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_cfg,
|
||||
model_config=model_config,
|
||||
message=message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables
|
||||
):
|
||||
yield event
|
||||
|
||||
# ==================== 多模型对比试运行 ====================
|
||||
|
||||
async def draft_run_compare(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
models: List[app_schema.ModelCompareItem],
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
parallel: bool = True,
|
||||
timeout: int = 60
|
||||
) -> Dict[str, Any]:
|
||||
"""多模型对比试运行
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
message: 用户消息
|
||||
models: 要对比的模型列表
|
||||
conversation_id: 会话ID
|
||||
user_id: 用户ID
|
||||
variables: 变量参数
|
||||
workspace_id: 工作空间ID
|
||||
parallel: 是否并行执行
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
Dict: 对比结果
|
||||
"""
|
||||
from app.models import ModelConfig
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(
|
||||
"多模型对比试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(models),
|
||||
"parallel": parallel
|
||||
}
|
||||
)
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 准备所有模型配置
|
||||
model_configs = []
|
||||
for model_item in models:
|
||||
model_config = self.db.get(ModelConfig, model_item.model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
# 合并参数:agent配置参数 + 请求覆盖参数
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
"model_config": model_config,
|
||||
"parameters": merged_parameters,
|
||||
"label": model_item.label or model_config.name,
|
||||
"model_config_id": model_item.model_config_id
|
||||
})
|
||||
|
||||
# 4. 调用 DraftRunService 的对比方法
|
||||
draft_service = DraftRunService(self.db)
|
||||
result = await draft_service.run_compare(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
parallel=parallel,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"多模型对比完成",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"successful": result["successful_count"],
|
||||
"failed": result["failed_count"]
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def draft_run_compare_stream(
|
||||
self,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
models: List[app_schema.ModelCompareItem],
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
parallel: bool = True,
|
||||
timeout: int = 60
|
||||
):
|
||||
"""多模型对比试运行(流式返回)
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
message: 用户消息
|
||||
models: 要对比的模型列表
|
||||
conversation_id: 会话ID
|
||||
user_id: 用户ID
|
||||
variables: 变量参数
|
||||
workspace_id: 工作空间ID
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
"""
|
||||
from app.models import ModelConfig
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
|
||||
logger.info(
|
||||
"多模型对比流式试运行",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"model_count": len(models)
|
||||
}
|
||||
)
|
||||
|
||||
# 1. 验证应用
|
||||
app = self._get_app_or_404(app_id)
|
||||
if app.type != "agent":
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
# 只读操作,允许访问共享应用
|
||||
self._validate_app_accessible(app, workspace_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = self.db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 准备所有模型配置
|
||||
model_configs = []
|
||||
for model_item in models:
|
||||
model_config = self.db.get(ModelConfig, model_item.model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
# 合并参数:agent配置参数 + 请求覆盖参数
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
"model_config": model_config,
|
||||
"parameters": merged_parameters,
|
||||
"label": model_item.label or model_config.name,
|
||||
"model_config_id": model_item.model_config_id
|
||||
})
|
||||
|
||||
# 4. 调用 DraftRunService 的流式对比方法
|
||||
draft_service = DraftRunService(self.db)
|
||||
async for event in draft_service.run_compare_stream(
|
||||
agent_config=agent_cfg,
|
||||
models=model_configs,
|
||||
message=message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
parallel=parallel,
|
||||
timeout=timeout
|
||||
):
|
||||
yield event
|
||||
|
||||
logger.info(
|
||||
"多模型对比流式完成",
|
||||
extra={"app_id": str(app_id)}
|
||||
)
|
||||
|
||||
|
||||
# ==================== 向后兼容的函数接口 ====================
|
||||
# 保留函数接口以兼容现有代码,但内部使用服务类
|
||||
|
||||
@@ -2278,53 +1912,6 @@ def get_apps_by_ids(
|
||||
return service.get_apps_by_ids(app_ids, workspace_id)
|
||||
|
||||
|
||||
# ==================== 向后兼容的函数接口 ====================
|
||||
|
||||
async def draft_run(
|
||||
db: Session,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""试运行 Agent(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
return await service.draft_run(
|
||||
app_id=app_id,
|
||||
message=message,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
|
||||
async def draft_run_stream(
|
||||
db: Session,
|
||||
*,
|
||||
app_id: uuid.UUID,
|
||||
message: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""试运行 Agent 流式返回(向后兼容接口)"""
|
||||
service = AppService(db)
|
||||
async for event in service.draft_run_stream(
|
||||
app_id=app_id,
|
||||
message=message,
|
||||
conversation_id=conversation_id,
|
||||
user_id=user_id,
|
||||
variables=variables,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_app_service(
|
||||
|
||||
101
api/app/services/audio_transcription_service.py
Normal file
101
api/app/services/audio_transcription_service.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
音频转文本服务
|
||||
|
||||
支持的服务商:
|
||||
- DashScope (阿里云通义千问)
|
||||
- OpenAI Whisper
|
||||
"""
|
||||
import httpx
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AudioTranscriptionService:
|
||||
"""音频转文本服务"""
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_dashscope(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用阿里云通义千问语音识别服务转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: DashScope API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
},
|
||||
json={
|
||||
"model": "paraformer-v2",
|
||||
"input": {
|
||||
"file_urls": [audio_url]
|
||||
},
|
||||
"parameters": {
|
||||
"language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"]
|
||||
}
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
if result.get("output", {}).get("results"):
|
||||
text = result["output"]["results"][0].get("transcription_text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
return "[音频转文本失败]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def transcribe_openai(audio_url: str, api_key: str) -> str:
|
||||
"""
|
||||
使用 OpenAI Whisper 转换音频为文本
|
||||
|
||||
Args:
|
||||
audio_url: 音频文件 URL
|
||||
api_key: OpenAI API Key
|
||||
|
||||
Returns:
|
||||
str: 转录的文本
|
||||
"""
|
||||
try:
|
||||
# 下载音频文件
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
audio_response = await client.get(audio_url)
|
||||
audio_response.raise_for_status()
|
||||
audio_data = audio_response.content
|
||||
|
||||
# 调用 Whisper API
|
||||
files = {"file": ("audio.mp3", audio_data, "audio/mpeg")}
|
||||
data = {"model": "whisper-1"}
|
||||
|
||||
response = await client.post(
|
||||
"https://api.openai.com/v1/audio/transcriptions",
|
||||
headers={"Authorization": f"Bearer {api_key}"},
|
||||
files=files,
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
text = result.get("text", "")
|
||||
logger.info(f"音频转文本成功: {len(text)} 字符")
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI Whisper 音频转文本失败: {e}")
|
||||
return f"[音频转文本失败: {str(e)}]"
|
||||
@@ -445,6 +445,7 @@ class CollaborativeOrchestrator:
|
||||
"provider": api_key_config.provider,
|
||||
"api_key": api_key_config.api_key,
|
||||
"api_base": api_key_config.api_base,
|
||||
"is_omni": api_key_config.is_omni,
|
||||
"model_parameters": config_data.get("model_parameters", {}),
|
||||
"api_key_id": api_key_config.id
|
||||
}
|
||||
@@ -511,6 +512,7 @@ class CollaborativeOrchestrator:
|
||||
provider=agent_config["provider"],
|
||||
api_key=agent_config["api_key"],
|
||||
base_url=agent_config.get("api_base"),
|
||||
is_omni=agent_config.get("is_omni", False),
|
||||
extra_params=extra_params
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.agent.agent_middleware import AgentMiddleware
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -26,6 +27,7 @@ from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
@@ -52,8 +54,12 @@ class LongTermMemoryInput(BaseModel):
|
||||
description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写")
|
||||
|
||||
|
||||
def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None):
|
||||
def create_long_term_memory_tool(
|
||||
memory_config: Dict[str, Any],
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
):
|
||||
"""创建记忆工具,
|
||||
|
||||
|
||||
@@ -61,6 +67,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
memory_config: 记忆配置
|
||||
end_user_id: 用户ID
|
||||
storage_type: 存储类型(可选)
|
||||
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||
|
||||
Returns:
|
||||
长期记忆工具
|
||||
@@ -188,7 +195,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||
|
||||
Args:
|
||||
query: 需要检索的问题或关键词
|
||||
kb_config: 知识库配置
|
||||
kb_ids: 知识库ID列表
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
检索到的相关知识内容
|
||||
@@ -232,17 +241,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
||||
return knowledge_retrieval_tool
|
||||
|
||||
|
||||
class DraftRunService:
|
||||
"""试运行服务类"""
|
||||
class AgentRunService:
|
||||
"""Agent运行服务类"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化试运行服务
|
||||
"""Agent运行服务
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
@staticmethod
|
||||
def prepare_variables(
|
||||
input_vars: dict | None,
|
||||
variables_config: dict
|
||||
) -> dict:
|
||||
input_vars = input_vars or {}
|
||||
for variable in variables_config:
|
||||
if variable.get("required") and variable.get("name") not in input_vars:
|
||||
raise ValueError(f"The required parameter '{variable.get('name')}' was not provided")
|
||||
return input_vars
|
||||
|
||||
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
||||
"""加载工具配置"""
|
||||
if not tools_config:
|
||||
return []
|
||||
tools = []
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if tools_config and isinstance(tools_config, list):
|
||||
for tool_config in tools_config:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif tools_config and isinstance(tools_config, dict):
|
||||
web_search_choice = tools_config.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search and web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def load_skill_config(
|
||||
self,
|
||||
skills_config: dict | None,
|
||||
message: str, tenant_id
|
||||
) -> tuple[list, str]:
|
||||
if not skills_config:
|
||||
return [], ""
|
||||
|
||||
tools = []
|
||||
skill_prompts = ""
|
||||
skill_enable = skills_config.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills_config)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
skill_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
|
||||
return tools, skill_prompts
|
||||
|
||||
def load_knowledge_retrieval_config(
|
||||
self,
|
||||
knowledge_retrieval_config: dict | None,
|
||||
user_id
|
||||
) -> list:
|
||||
if not knowledge_retrieval_config:
|
||||
return []
|
||||
|
||||
tools = []
|
||||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
memory_config: dict | None,
|
||||
user_id,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
) -> tuple[list, bool]:
|
||||
"""加载长期记忆配置"""
|
||||
if not memory_config:
|
||||
return [], False
|
||||
|
||||
tools = []
|
||||
if memory_config.get("enabled"):
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
return tools, bool(memory_config.get("enabled"))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
*,
|
||||
@@ -270,19 +403,21 @@ class DraftRunService:
|
||||
conversation_id: 会话ID(用于多轮对话)
|
||||
user_id: 用户ID
|
||||
variables: 自定义变量参数值
|
||||
storage_type: 存储类型(可选)
|
||||
user_rag_memory_id: 用户RAG记忆ID(可选)
|
||||
web_search: 是否启用网络搜索(默认True)
|
||||
memory: 是否启用长期记忆(默认True)
|
||||
sub_agent: 是否为子代理调用(默认False)
|
||||
files: 多模态文件列表(可选)
|
||||
|
||||
Returns:
|
||||
Dict: 包含 AI 回复和元数据的字典
|
||||
"""
|
||||
memory_flag = False
|
||||
|
||||
print('===========', storage_type)
|
||||
|
||||
print(user_id)
|
||||
if variables == None: variables = {}
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
|
||||
start_time = time.time()
|
||||
tools_config: dict | list | None = agent_config.tools
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
|
||||
try:
|
||||
# 1. 获取 API Key 配置
|
||||
@@ -302,112 +437,40 @@ class DraftRunService:
|
||||
agent_config=agent_config
|
||||
)
|
||||
|
||||
items_params = variables
|
||||
if sub_agent:
|
||||
variables = self.prepare_variables(variables, agent_config.variables)
|
||||
else:
|
||||
# FIXME: subagent input valid
|
||||
variables = variables or {}
|
||||
|
||||
system_prompt = render_prompt_message(
|
||||
agent_config.system_prompt, # 修正拼写错误
|
||||
agent_config.system_prompt,
|
||||
PromptMessageRole.USER,
|
||||
items_params
|
||||
variables
|
||||
)
|
||||
|
||||
# 3. 处理系统提示词(支持变量替换)
|
||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||
print('系统提示词:', system_prompt)
|
||||
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_config in agent_config.tools:
|
||||
print("+" * 50)
|
||||
print(f"agent_config:{agent_config}")
|
||||
print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
memory_flag = True
|
||||
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
memory_tools, memory_flag = self.load_memory_config(
|
||||
memory_config, user_id, storage_type, user_rag_memory_id
|
||||
)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
@@ -415,6 +478,7 @@ class DraftRunService:
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -431,7 +495,7 @@ class DraftRunService:
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
@@ -442,7 +506,7 @@ class DraftRunService:
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
@@ -481,7 +545,7 @@ class DraftRunService:
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||
|
||||
# 9. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -556,16 +620,21 @@ class DraftRunService:
|
||||
Yields:
|
||||
str: SSE 格式的事件数据
|
||||
"""
|
||||
memory_flag = False
|
||||
if variables == None: variables = {}
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
tools_config: dict | list | None = agent_config.tools
|
||||
skills_config: dict | None = agent_config.skills
|
||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||
memory_config: dict | None = agent_config.memory
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 1. 获取 API Key 配置
|
||||
api_key_config = await self._get_api_key(model_config.id)
|
||||
if not sub_agent:
|
||||
variables = self.prepare_variables(variables, agent_config.variables)
|
||||
else:
|
||||
# FIXME: subagent input valid
|
||||
variables = variables or {}
|
||||
|
||||
# 2. 合并模型参数
|
||||
effective_params = ModelParameterMerger.get_effective_parameters(
|
||||
@@ -587,95 +656,22 @@ class DraftRunService:
|
||||
# 4. 准备工具列表
|
||||
tools = []
|
||||
|
||||
tool_service = ToolService(self.db)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
for tool_config in agent_config.tools:
|
||||
# print("+"*50)
|
||||
# print(f"agent_config:{agent_config}")
|
||||
# print(f"tool_config:{tool_config}")
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
tools.extend(self.load_tools_config(tools_config, web_search, tenant_id))
|
||||
skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 加载技能关联的工具
|
||||
if hasattr(agent_config, 'skills') and agent_config.skills:
|
||||
skills = agent_config.skills
|
||||
skill_enable = skills.get("enabled", False)
|
||||
if skill_enable:
|
||||
middleware = AgentMiddleware(skills=skills)
|
||||
skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id)
|
||||
tools.extend(skill_tools)
|
||||
logger.debug(f"已加载 {len(skill_tools)} 个技能工具")
|
||||
|
||||
# 应用动态过滤
|
||||
if skill_configs:
|
||||
tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs,
|
||||
tool_to_skill_map)
|
||||
logger.debug(f"过滤后剩余 {len(tools)} 个工具")
|
||||
active_prompts = AgentMiddleware.get_active_prompts(
|
||||
activated_skill_ids, skill_configs
|
||||
)
|
||||
system_prompt = f"{system_prompt}\n\n{active_prompts}"
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
kb_config = agent_config.knowledge_retrieval
|
||||
knowledge_bases = kb_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
if kb_ids:
|
||||
# 创建知识库检索工具
|
||||
kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加知识库检索工具",
|
||||
extra={
|
||||
"kb_ids": kb_ids,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
if memory:
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
memory_flag = True
|
||||
memory_config = agent_config.memory
|
||||
if user_id:
|
||||
# 创建长期记忆工具
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加长期记忆工具",
|
||||
extra={
|
||||
"user_id": user_id,
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
tools.extend(memory_tools)
|
||||
|
||||
# 4. 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
@@ -683,6 +679,7 @@ class DraftRunService:
|
||||
api_key=api_key_config["api_key"],
|
||||
provider=api_key_config.get("provider", "openai"),
|
||||
api_base=api_key_config.get("api_base"),
|
||||
is_omni=api_key_config.get("is_omni", False),
|
||||
temperature=effective_params.get("temperature", 0.7),
|
||||
max_tokens=effective_params.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
@@ -700,10 +697,10 @@ class DraftRunService:
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = []
|
||||
if agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if memory_config and memory_config.get("enabled"):
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=agent_config.memory.get("max_history", 10)
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
@@ -711,7 +708,7 @@ class DraftRunService:
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, provider=provider)
|
||||
multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False))
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}")
|
||||
|
||||
@@ -761,7 +758,7 @@ class DraftRunService:
|
||||
})
|
||||
|
||||
# 10. 保存会话消息
|
||||
if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"):
|
||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
||||
await self._save_conversation_message(
|
||||
conversation_id=conversation_id,
|
||||
user_message=message,
|
||||
@@ -809,7 +806,7 @@ class DraftRunService:
|
||||
"""
|
||||
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]:
|
||||
async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict:
|
||||
"""获取模型的 API Key
|
||||
|
||||
Args:
|
||||
@@ -846,7 +843,8 @@ class DraftRunService:
|
||||
"provider": api_key.provider,
|
||||
"api_key": api_key.api_key,
|
||||
"api_base": api_key.api_base,
|
||||
"api_key_id": api_key.id
|
||||
"api_key_id": api_key.id,
|
||||
"is_omni": api_key.is_omni
|
||||
}
|
||||
|
||||
async def _ensure_conversation(
|
||||
@@ -966,7 +964,6 @@ class DraftRunService:
|
||||
List[Dict]: 历史消息列表
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
history = conversation_service.get_conversation_history(
|
||||
@@ -1486,6 +1483,15 @@ class DraftRunService:
|
||||
"conversation_id": returned_conversation_id,
|
||||
"content": chunk
|
||||
}))
|
||||
|
||||
if event_type == "error" and event_data:
|
||||
await event_queue.put(self._format_sse_event("model_error", {
|
||||
"model_index": idx,
|
||||
"model_config_id": model_config_id,
|
||||
"label": model_label,
|
||||
"conversation_id": returned_conversation_id,
|
||||
"error": event_data.get("error", "未知错误")
|
||||
}))
|
||||
except Exception as e:
|
||||
logger.warning(f"解析流式事件失败: {e}")
|
||||
finally:
|
||||
@@ -1670,41 +1676,3 @@ class DraftRunService:
|
||||
"total_time": sum(r.get("elapsed_time", 0) for r in results)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def draft_run(
|
||||
db: Session,
|
||||
*,
|
||||
agent_config: AgentConfig,
|
||||
model_config: ModelConfig,
|
||||
message: str,
|
||||
user_id: Optional[str] = None,
|
||||
kb_ids: Optional[List[str]] = None,
|
||||
similarity_threshold: float = 0.7,
|
||||
top_k: int = 3
|
||||
) -> Dict[str, Any]:
|
||||
"""试运行 Agent(便捷函数)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
agent_config: Agent 配置
|
||||
model_config: 模型配置
|
||||
message: 用户消息
|
||||
user_id: 用户ID
|
||||
kb_ids: 知识库ID列表
|
||||
similarity_threshold: 相似度阈值
|
||||
top_k: 检索返回的文档数量
|
||||
|
||||
Returns:
|
||||
Dict: 包含 AI 回复和元数据的字典
|
||||
"""
|
||||
service = DraftRunService(db)
|
||||
return await service.run(
|
||||
agent_config=agent_config,
|
||||
model_config=model_config,
|
||||
message=message,
|
||||
user_id=user_id,
|
||||
kb_ids=kb_ids,
|
||||
similarity_threshold=similarity_threshold,
|
||||
top_k=top_k
|
||||
)
|
||||
|
||||
@@ -843,32 +843,33 @@ class EmotionAnalyticsService:
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""从 Redis 缓存获取个性化情绪建议
|
||||
"""从数据库获取个性化情绪建议
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 缓存的建议数据,如果不存在或已过期返回 None
|
||||
Dict: 存储的建议数据,如果不存在返回 None
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"尝试从 Redis 缓存获取情绪建议: user={end_user_id}")
|
||||
logger.info(f"尝试从数据库获取情绪建议: user={end_user_id}")
|
||||
|
||||
# 从 Redis 获取缓存
|
||||
cached_data = await EmotionMemoryCache.get_emotion_suggestions(end_user_id)
|
||||
# 从数据库获取存储记录
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
storage = repo.get_by_end_user_id(end_user_id)
|
||||
|
||||
if cached_data is None:
|
||||
logger.info(f"用户 {end_user_id} 的建议缓存不存在或已过期")
|
||||
if storage is None or storage.emotion_suggestions is None:
|
||||
logger.info(f"用户 {end_user_id} 的建议数据不存在")
|
||||
return None
|
||||
|
||||
logger.info(f"成功从 Redis 缓存获取建议: user={end_user_id}")
|
||||
return cached_data
|
||||
logger.info(f"成功从数据库获取建议: user={end_user_id}")
|
||||
return storage.emotion_suggestions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 缓存获取建议失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"从数据库获取建议失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def save_suggestions_cache(
|
||||
@@ -876,36 +877,27 @@ class EmotionAnalyticsService:
|
||||
end_user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
db: Session,
|
||||
expires_hours: int = 24
|
||||
expires_hours: int = 24 # 参数保留以保持接口兼容性
|
||||
) -> None:
|
||||
"""保存建议到 Redis 缓存
|
||||
"""保存建议到数据库
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
suggestions_data: 建议数据
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
expires_hours: 过期时间(小时),默认24小时
|
||||
db: 数据库会话
|
||||
expires_hours: 保留参数(兼容性)
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.emotion_memory import EmotionMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"保存建议到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
||||
logger.info(f"保存建议到数据库: user={end_user_id}")
|
||||
|
||||
# 计算过期时间(秒)
|
||||
expire_seconds = expires_hours * 3600
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
repo.update_emotion_suggestions(end_user_id, suggestions_data)
|
||||
db.commit()
|
||||
|
||||
# 保存到 Redis
|
||||
success = await EmotionMemoryCache.set_emotion_suggestions(
|
||||
user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
expire=expire_seconds
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"建议缓存保存成功: user={end_user_id}")
|
||||
else:
|
||||
logger.warning(f"建议缓存保存失败: user={end_user_id}")
|
||||
logger.info(f"建议保存成功: user={end_user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存建议缓存失败: {str(e)}", exc_info=True)
|
||||
# 不抛出异常,缓存失败不应影响主流程
|
||||
db.rollback()
|
||||
logger.error(f"保存建议失败: {str(e)}", exc_info=True)
|
||||
@@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs(
|
||||
provider=model_api_key.provider,
|
||||
api_key=model_api_key.api_key,
|
||||
base_url=model_api_key.api_base,
|
||||
is_omni=model_api_key.is_omni,
|
||||
extra_params={
|
||||
"temperature": 0.7,
|
||||
"max_tokens": 2000,
|
||||
|
||||
@@ -422,32 +422,33 @@ class ImplicitMemoryService:
|
||||
end_user_id: str,
|
||||
db: Session
|
||||
) -> Optional[dict]:
|
||||
"""从 Redis 缓存获取完整用户画像
|
||||
"""从数据库获取完整用户画像
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 缓存的画像数据,如果不存在或已过期返回 None
|
||||
Dict: 存储的画像数据,如果不存在返回 None
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.implicit_memory import ImplicitMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"尝试从 Redis 缓存获取用户画像: user={end_user_id}")
|
||||
logger.info(f"尝试从数据库获取用户画像: user={end_user_id}")
|
||||
|
||||
# 从 Redis 获取缓存
|
||||
cached_data = await ImplicitMemoryCache.get_user_profile(end_user_id)
|
||||
# 从数据库获取存储记录
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
storage = repo.get_by_end_user_id(end_user_id)
|
||||
|
||||
if cached_data is None:
|
||||
logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
if storage is None or storage.implicit_profile is None:
|
||||
logger.info(f"用户 {end_user_id} 的画像数据不存在")
|
||||
return None
|
||||
|
||||
logger.info(f"成功从 Redis 缓存获取用户画像: user={end_user_id}")
|
||||
return cached_data
|
||||
logger.info(f"成功从数据库获取用户画像: user={end_user_id}")
|
||||
return storage.implicit_profile
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"从 Redis 缓存获取用户画像失败: {str(e)}", exc_info=True)
|
||||
logger.error(f"从数据库获取用户画像失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def save_profile_cache(
|
||||
@@ -455,36 +456,27 @@ class ImplicitMemoryService:
|
||||
end_user_id: str,
|
||||
profile_data: dict,
|
||||
db: Session,
|
||||
expires_hours: int = 168 # 默认7天
|
||||
expires_hours: int = 168 # 参数保留以保持接口兼容性
|
||||
) -> None:
|
||||
"""保存用户画像到 Redis 缓存
|
||||
"""保存用户画像到数据库
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
profile_data: 画像数据
|
||||
db: 数据库会话(保留参数以保持接口兼容性)
|
||||
expires_hours: 过期时间(小时),默认168小时(7天)
|
||||
db: 数据库会话
|
||||
expires_hours: 保留参数(兼容性)
|
||||
"""
|
||||
try:
|
||||
from app.cache.memory.implicit_memory import ImplicitMemoryCache
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
|
||||
logger.info(f"保存用户画像到 Redis 缓存: user={end_user_id}, expires={expires_hours}小时")
|
||||
logger.info(f"保存用户画像到数据库: user={end_user_id}")
|
||||
|
||||
# 计算过期时间(秒)
|
||||
expire_seconds = expires_hours * 3600
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
repo.update_implicit_profile(end_user_id, profile_data)
|
||||
db.commit()
|
||||
|
||||
# 保存到 Redis
|
||||
success = await ImplicitMemoryCache.set_user_profile(
|
||||
user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
expire=expire_seconds
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"用户画像缓存保存成功: user={end_user_id}")
|
||||
else:
|
||||
logger.warning(f"用户画像缓存保存失败: user={end_user_id}")
|
||||
logger.info(f"用户画像保存成功: user={end_user_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存用户画像缓存失败: {str(e)}", exc_info=True)
|
||||
# 不抛出异常,缓存失败不应影响主流程
|
||||
db.rollback()
|
||||
logger.error(f"保存用户画像失败: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -9,6 +9,8 @@ load_dotenv()
|
||||
|
||||
# 读取web_search环境变量
|
||||
web_search_value = os.getenv('web_search')
|
||||
|
||||
|
||||
def Search(query):
|
||||
url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
|
||||
api_key = web_search_value
|
||||
@@ -18,23 +20,24 @@ def Search(query):
|
||||
"role": "user",
|
||||
"content": query
|
||||
}
|
||||
], #搜索输入
|
||||
"edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。
|
||||
"search_source": "baidu_search_v2", #使用的搜索引擎版本
|
||||
"resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5
|
||||
], # 搜索输入
|
||||
"edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。
|
||||
"search_source": "baidu_search_v2", # 使用的搜索引擎版本
|
||||
"resource_type_filter": [{"type": "web", "top_k": 20}],
|
||||
# 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5
|
||||
"search_filter": {
|
||||
"range": {
|
||||
"page_time": {
|
||||
"gte": "now-1w/d", #时间查询参数,大于或等于
|
||||
"lt": "now/d", #时间查询参数,小于
|
||||
"gt": "", #时间查询参数,大于
|
||||
"lte": "" #时间查询参数,小于或等于
|
||||
"gte": "now-1w/d", # 时间查询参数,大于或等于
|
||||
"lt": "now/d", # 时间查询参数,小于
|
||||
"gt": "", # 时间查询参数,大于
|
||||
"lte": "" # 时间查询参数,小于或等于
|
||||
}
|
||||
}
|
||||
},
|
||||
"block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表
|
||||
"search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year
|
||||
"enable_full_content":True #是否输出网页完整原文
|
||||
"block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表
|
||||
"search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year
|
||||
"enable_full_content": True # 是否输出网页完整原文
|
||||
}, ensure_ascii=False)
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
@@ -42,10 +45,10 @@ def Search(query):
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json()
|
||||
content=[]
|
||||
content = []
|
||||
for i in response['references']:
|
||||
title=i['title']
|
||||
snippet=i['snippet']
|
||||
content.append(title+';'+snippet)
|
||||
content='。'.join(content)
|
||||
return content
|
||||
title = i['title']
|
||||
snippet = i['snippet']
|
||||
content.append(title + ';' + snippet)
|
||||
content = '。'.join(content)
|
||||
return content
|
||||
|
||||
@@ -414,6 +414,7 @@ class LLMRouter:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.3,
|
||||
max_tokens=500
|
||||
)
|
||||
|
||||
@@ -392,6 +392,7 @@ class MasterAgentRouter:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
extra_params = extra_params
|
||||
)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import (
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
@@ -816,11 +816,10 @@ class MemoryAgentService:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||||
2. Neo4j 中的 Memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
|
||||
3. total: 所有类型的总和
|
||||
2. total: 所有类型的总和
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户组ID(可选,未提供时 Memory 统计为 0)
|
||||
- end_user_id: 用户组ID(可选,保留参数以保持接口兼容性)
|
||||
- only_active: 是否仅统计有效记录
|
||||
- current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0)
|
||||
- db: 数据库会话
|
||||
@@ -831,7 +830,6 @@ class MemoryAgentService:
|
||||
"Web": count,
|
||||
"Third-party": count,
|
||||
"Folder": count,
|
||||
"Memory": chunk_count,
|
||||
"total": sum_of_all
|
||||
}
|
||||
"""
|
||||
@@ -878,51 +876,8 @@ class MemoryAgentService:
|
||||
logger.error(f"知识库类型统计失败: {e}")
|
||||
raise Exception(f"知识库类型统计失败: {e}")
|
||||
|
||||
# 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数)
|
||||
try:
|
||||
if current_workspace_id:
|
||||
# 获取当前空间下的所有宿主
|
||||
from app.repositories import app_repository, end_user_repository
|
||||
from app.schemas.app_schema import App as AppSchema
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
# 查询应用并转换为 Pydantic 模型
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id)
|
||||
apps = [AppSchema.model_validate(h) for h in apps_orm]
|
||||
app_ids = [app.id for app in apps]
|
||||
|
||||
# 获取所有宿主
|
||||
end_users = []
|
||||
for app_id in app_ids:
|
||||
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
|
||||
end_users.extend(h for h in end_user_orm_list)
|
||||
|
||||
# 统计所有宿主的 Chunk 总数
|
||||
total_chunks = 0
|
||||
for end_user in end_users:
|
||||
end_user_id_str = str(end_user.id)
|
||||
memory_query = """
|
||||
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
|
||||
"""
|
||||
neo4j_result = await _neo4j_connector.execute_query(
|
||||
memory_query,
|
||||
end_user_id=end_user_id_str,
|
||||
)
|
||||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||||
total_chunks += chunk_count
|
||||
logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}")
|
||||
|
||||
result["Memory"] = total_chunks
|
||||
logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}")
|
||||
else:
|
||||
# 没有 workspace_id 时,返回 0
|
||||
result["Memory"] = 0
|
||||
logger.info("未提供 workspace_id,memory 统计为 0")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)
|
||||
# 如果 Neo4j 查询失败,memory 设为 0
|
||||
result["Memory"] = 0
|
||||
# 2. 统计 Neo4j 中的 memory 总量已移除
|
||||
# memory 字段不再返回
|
||||
|
||||
# 3. 计算知识库类型总和(不包括 memory)
|
||||
result["total"] = (
|
||||
@@ -935,36 +890,36 @@ class MemoryAgentService:
|
||||
return result
|
||||
|
||||
|
||||
async def get_hot_memory_tags_by_user(
|
||||
|
||||
async def get_interest_distribution_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
limit: int = 5,
|
||||
language: str = "zh"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签。
|
||||
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段
|
||||
- end_user_id: 用户ID(必填)
|
||||
- limit: 返回标签数量限制
|
||||
- language: 输出语言("zh" 中文, "en" 英文)
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
"""
|
||||
try:
|
||||
# by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度)
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||
payload = []
|
||||
for tag, freq in tags:
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
return payload
|
||||
tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
|
||||
return [{"name": tag, "frequency": freq} for tag, freq in tags]
|
||||
except Exception as e:
|
||||
logger.error(f"热门记忆标签查询失败: {e}")
|
||||
raise Exception(f"热门记忆标签查询失败: {e}")
|
||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||
|
||||
|
||||
async def get_user_profile(
|
||||
|
||||
@@ -140,9 +140,11 @@ class MemoryAPIService:
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||
result = await MemoryAgentService().write_memory(
|
||||
end_user_id=end_user_id,
|
||||
messages=message,
|
||||
messages=messages,
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
@@ -151,9 +153,18 @@ class MemoryAPIService:
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||
|
||||
# result may be a string "success" or a dict with a "status" key
|
||||
# Preserve the full dict so callers don't silently lose extra fields
|
||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||
if isinstance(result, dict):
|
||||
return {
|
||||
**result,
|
||||
"status": result.get("status", "unknown"),
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
return {
|
||||
"status": "success" if result == "success" else result,
|
||||
"end_user_id": end_user_id
|
||||
"status": result if isinstance(result, str) else "success",
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
|
||||
except ConfigurationError as e:
|
||||
|
||||
@@ -390,19 +390,59 @@ def get_rag_total_kb(
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
|
||||
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
|
||||
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
|
||||
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
|
||||
return total_kb
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_rag_user_kb_total_chunk(
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。
|
||||
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from sqlalchemy import func
|
||||
|
||||
# 通过 App 关联取该 workspace 下所有 end_user_id
|
||||
end_user_ids = [
|
||||
str(eid) for (eid,) in db.query(EndUser.id)
|
||||
.join(App, EndUser.app_id == App.id)
|
||||
.filter(App.workspace_id == workspace_id)
|
||||
.all()
|
||||
]
|
||||
if not end_user_ids:
|
||||
return 0
|
||||
|
||||
file_names = [f"{uid}.txt" for uid in end_user_ids]
|
||||
result = db.query(func.sum(Document.chunk_num)).filter(
|
||||
Document.file_name.in_(file_names)
|
||||
).scalar()
|
||||
|
||||
total_chunk = int(result or 0)
|
||||
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
|
||||
return total_chunk
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_user_total_chunk(
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
from typing import Dict, List
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
|
||||
from app.repositories.memory_short_repository import (
|
||||
LongTermMemoryRepository,
|
||||
ShortTermMemoryRepository,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
db=next(get_db())
|
||||
|
||||
|
||||
class ShortService:
|
||||
def __init__(self, end_user_id):
|
||||
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||
"""Service for short-term memory queries.
|
||||
|
||||
Args:
|
||||
end_user_id: The end user identifier to query memories for.
|
||||
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||
"""
|
||||
self.short_repo = ShortTermMemoryRepository(db)
|
||||
self.end_user_id = end_user_id
|
||||
|
||||
def get_short_databasets(self):
|
||||
def get_short_databasets(self) -> List[Dict]:
|
||||
"""Retrieve the latest short-term memory entries for the user.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of memory dicts with retrieval, message, and answer keys.
|
||||
"""
|
||||
short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3)
|
||||
short_result = []
|
||||
for memory in short_memories:
|
||||
deep_expanded = {} # Create a new dictionary for each memory
|
||||
deep_expanded = {}
|
||||
messages = memory.messages
|
||||
aimessages = memory.aimessages
|
||||
retrieved_content = memory.retrieved_content or []
|
||||
@@ -27,23 +42,41 @@ class ShortService:
|
||||
for item in retrieved_content:
|
||||
if isinstance(item, dict):
|
||||
for key, values in item.items():
|
||||
retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"})
|
||||
retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"})
|
||||
|
||||
deep_expanded['retrieval'] = retrieval_source
|
||||
deep_expanded['message'] = messages # 修正拼写错误
|
||||
deep_expanded['message'] = messages
|
||||
deep_expanded['answer'] = aimessages
|
||||
short_result.append(deep_expanded)
|
||||
return short_result
|
||||
def get_short_count(self):
|
||||
|
||||
def get_short_count(self) -> int:
|
||||
"""Count total short-term memory entries for the user.
|
||||
|
||||
Returns:
|
||||
int: Number of short-term memory records.
|
||||
"""
|
||||
short_count = self.short_repo.count_by_user_id(self.end_user_id)
|
||||
return short_count
|
||||
|
||||
|
||||
class LongService:
|
||||
def __init__(self, end_user_id):
|
||||
def __init__(self, end_user_id: str, db: Session) -> None:
|
||||
"""Service for long-term memory queries.
|
||||
|
||||
Args:
|
||||
end_user_id: The end user identifier to query memories for.
|
||||
db: SQLAlchemy database session (caller-managed lifecycle).
|
||||
"""
|
||||
self.long_repo = LongTermMemoryRepository(db)
|
||||
self.end_user_id = end_user_id
|
||||
def get_long_databasets(self):
|
||||
# 获取长期记忆数据
|
||||
|
||||
def get_long_databasets(self) -> List[Dict]:
|
||||
"""Retrieve long-term memory retrieval data for the user.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of dicts with query and retrieval keys.
|
||||
"""
|
||||
long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1)
|
||||
|
||||
long_result = []
|
||||
|
||||
@@ -90,7 +90,8 @@ class ModelConfigService:
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello"
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -102,6 +103,7 @@ class ModelConfigService:
|
||||
api_base: API基础URL
|
||||
model_type: 模型类型 (llm/chat/embedding/rerank)
|
||||
test_message: 测试消息
|
||||
is_omni: 是否为Omni模型
|
||||
|
||||
Returns:
|
||||
Dict: 验证结果
|
||||
@@ -114,14 +116,27 @@ class ModelConfigService:
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
# dashscope 的 omni 模型需要使用 compatible-mode
|
||||
if provider.lower() == ModelProvider.DASHSCOPE and is_omni:
|
||||
if not api_base:
|
||||
api_base = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=ModelProvider.OPENAI,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
else:
|
||||
model_config = RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.7,
|
||||
max_tokens=100
|
||||
)
|
||||
|
||||
# 根据模型类型选择不同的验证方式
|
||||
model_type_lower = model_type.lower()
|
||||
@@ -257,8 +272,9 @@ class ModelConfigService:
|
||||
provider=model_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_data.type, # 传递模型类型
|
||||
test_message="Hello"
|
||||
model_type=model_data.type,
|
||||
test_message="Hello",
|
||||
is_omni=model_data.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -279,6 +295,9 @@ class ModelConfigService:
|
||||
for api_key_data in api_key_datas:
|
||||
api_key_data.model_name = model_data.name
|
||||
api_key_data.provider = model_data.provider
|
||||
# 同步capability和is_omni
|
||||
api_key_data.capability = model_data.capability
|
||||
api_key_data.is_omni = model_data.is_omni
|
||||
api_key_create_schema = ModelApiKeyCreate(
|
||||
model_config_ids=[model.id],
|
||||
**api_key_data.model_dump()
|
||||
@@ -497,6 +516,8 @@ class ModelApiKeyService:
|
||||
existing_key.config = data.config
|
||||
existing_key.priority = data.priority
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
@@ -513,7 +534,8 @@ class ModelApiKeyService:
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=data.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
@@ -528,6 +550,8 @@ class ModelApiKeyService:
|
||||
provider=data.provider,
|
||||
api_key=data.api_key,
|
||||
api_base=data.api_base,
|
||||
capability=data.capability if data.capability is not None else model_config.capability,
|
||||
is_omni=data.is_omni if data.is_omni is not None else model_config.is_omni,
|
||||
config=data.config,
|
||||
is_active=data.is_active,
|
||||
priority=data.priority
|
||||
@@ -572,6 +596,8 @@ class ModelApiKeyService:
|
||||
existing_key.config = api_key_data.config
|
||||
existing_key.priority = api_key_data.priority
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
@@ -589,7 +615,8 @@ class ModelApiKeyService:
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -620,7 +647,8 @@ class ModelApiKeyService:
|
||||
api_key=api_key_data.api_key or existing_api_key.api_key,
|
||||
api_base=api_key_data.api_base or existing_api_key.api_base,
|
||||
model_type=model_config.type,
|
||||
test_message="Hello"
|
||||
test_message="Hello",
|
||||
is_omni=model_config.is_omni
|
||||
)
|
||||
if not validation_result["valid"]:
|
||||
raise BusinessException(
|
||||
@@ -755,6 +783,8 @@ class ModelBaseService:
|
||||
"type": model_base.type,
|
||||
"logo": model_base.logo,
|
||||
"description": model_base.description,
|
||||
"capability": model_base.capability,
|
||||
"is_omni": model_base.is_omni,
|
||||
"is_composite": False
|
||||
}
|
||||
model_config = ModelConfigRepository.create(db, model_config_data)
|
||||
|
||||
@@ -123,11 +123,14 @@ class MultiAgentOrchestrator:
|
||||
user_id: 用户 ID
|
||||
variables: 变量参数
|
||||
use_llm_routing: 是否使用 LLM 路由
|
||||
web_search: 是否启用网络搜索
|
||||
memory: 是否启用记忆功能
|
||||
storage_type: 存储类型
|
||||
user_rag_memory_id: 用户 RAG 记忆 ID
|
||||
|
||||
Yields:
|
||||
SSE 格式的事件流
|
||||
"""
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
@@ -200,7 +203,8 @@ class MultiAgentOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"多 Agent 任务执行失败(流式)",
|
||||
extra={"error": str(e), "mode": self._normalized_mode}
|
||||
extra={"error": str(e), "mode": self._normalized_mode},
|
||||
exc_info=True
|
||||
)
|
||||
# 发送错误事件
|
||||
yield self._format_sse_event("error", {
|
||||
@@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator:
|
||||
Yields:
|
||||
SSE 格式的事件流
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
# 获取模型配置
|
||||
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
||||
@@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator:
|
||||
)
|
||||
|
||||
# 流式执行 Agent
|
||||
draft_service = DraftRunService(self.db)
|
||||
draft_service = AgentRunService(self.db)
|
||||
async for event in draft_service.run_stream(
|
||||
agent_config=agent_config,
|
||||
model_config=model_config,
|
||||
@@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator:
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
from app.services.draft_run_service import DraftRunService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
|
||||
# 获取模型配置
|
||||
model_config = self.db.get(ModelConfig, agent_config.default_model_config_id)
|
||||
@@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator:
|
||||
)
|
||||
|
||||
# 执行 Agent
|
||||
draft_service = DraftRunService(self.db)
|
||||
draft_service = AgentRunService(self.db)
|
||||
result = await draft_service.run(
|
||||
agent_config=agent_config,
|
||||
model_config=model_config,
|
||||
@@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator:
|
||||
self.memory = config_data.get("memory")
|
||||
self.variables = config_data.get("variables", [])
|
||||
self.tools = config_data.get("tools", {})
|
||||
self.skills = config_data.get("skills", {})
|
||||
self.default_model_config_id = release.default_model_config_id
|
||||
|
||||
return AgentConfigProxy(release, app, config_data)
|
||||
@@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.7, # 整合任务使用中等温度
|
||||
max_tokens=2000
|
||||
)
|
||||
@@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator:
|
||||
provider=api_key_config.provider,
|
||||
api_key=api_key_config.api_key,
|
||||
base_url=api_key_config.api_base,
|
||||
is_omni=api_key_config.is_omni,
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
extra_params={"streaming": True} # 启用流式输出
|
||||
|
||||
@@ -267,7 +267,7 @@ class MultiAgentService:
|
||||
|
||||
# 2. 验证模型配置(如果提供了)
|
||||
if data.default_model_config_id:
|
||||
model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id)
|
||||
model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id)
|
||||
if not model_api_key:
|
||||
raise ResourceNotFoundException("模型配置", str(data.default_model_config_id))
|
||||
|
||||
|
||||
@@ -9,47 +9,100 @@
|
||||
- OpenAI: 支持 URL 和 base64 格式
|
||||
"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Protocol
|
||||
import httpx
|
||||
import base64
|
||||
from typing import List, Dict, Any, Optional
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.orm import Session
|
||||
from docx import Document
|
||||
import io
|
||||
import PyPDF2
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.app_schema import FileInput, FileType, TransferMethod
|
||||
from app.models.generic_file_model import GenericFile
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
from app.services.audio_transcription_service import AudioTranscriptionService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ImageFormatStrategy(Protocol):
|
||||
"""图片格式策略接口"""
|
||||
class MultimodalFormatStrategy(ABC):
|
||||
"""多模态格式策略基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""格式化图片"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""格式化文档"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]:
|
||||
"""格式化音频"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""格式化视频"""
|
||||
pass
|
||||
|
||||
|
||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||
"""通义千问策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""将图片 URL 转换为特定 provider 的格式"""
|
||||
...
|
||||
|
||||
|
||||
class DashScopeImageStrategy:
|
||||
"""通义千问图片格式策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问格式: {"type": "image", "image": "url"}"""
|
||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""通义千问文档格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
class BedrockImageStrategy:
|
||||
"""Bedrock/Anthropic 图片格式策略"""
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
通义千问音频格式
|
||||
- 原生支持: qwen-audio 系列
|
||||
- 其他模型: 需要转录为文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
}
|
||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||
return {
|
||||
"type": "audio",
|
||||
"audio": url
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""通义千问视频格式(qwen-vl 系列原生支持)"""
|
||||
return {
|
||||
"type": "video",
|
||||
"video": url
|
||||
}
|
||||
|
||||
|
||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||
"""Bedrock/Anthropic 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 格式: base64 编码
|
||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
logger.info(f"下载并编码图片: {url}")
|
||||
@@ -84,9 +137,46 @@ class BedrockImageStrategy:
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||
# Bedrock 文档需要 base64 编码
|
||||
text_bytes = text.encode('utf-8')
|
||||
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
||||
|
||||
class OpenAIImageStrategy:
|
||||
"""OpenAI 图片格式策略"""
|
||||
return {
|
||||
"type": "document",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "text/plain",
|
||||
"data": base64_text
|
||||
}
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Bedrock/Anthropic 音频格式
|
||||
不支持原生音频,必须转录为文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频转录]\n{transcription}"
|
||||
}
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]"
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""Bedrock/Anthropic 视频格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
|
||||
}
|
||||
|
||||
|
||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||
"""OpenAI 策略"""
|
||||
|
||||
async def format_image(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||
@@ -97,29 +187,97 @@ class OpenAIImageStrategy:
|
||||
}
|
||||
}
|
||||
|
||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
||||
"""OpenAI 文档格式"""
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
|
||||
async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
OpenAI 音频格式
|
||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||
- 其他模型使用转录文本
|
||||
"""
|
||||
if transcription:
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||
}
|
||||
|
||||
# OpenAI 音频需要 base64 编码
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
audio_data = response.content
|
||||
base64_audio = base64.b64encode(audio_data).decode('utf-8')
|
||||
# 1. 优先从 file_type (MIME) 取扩展名
|
||||
file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None
|
||||
# 2. 从响应头 content-type 取
|
||||
if not file_ext:
|
||||
ct = response.headers.get("content-type", "")
|
||||
file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None
|
||||
# 3. 从 URL 路径取扩展名
|
||||
if not file_ext:
|
||||
file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None
|
||||
# 4. 默认 wav
|
||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||
file_ext = "wav" if not file_ext else file_ext
|
||||
|
||||
return {
|
||||
"type": "input_audio",
|
||||
"input_audio": {
|
||||
"data": f"data:;base64,{base64_audio}",
|
||||
"format": file_ext
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"下载音频失败: {e}")
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
||||
"""OpenAI 视频格式"""
|
||||
return {
|
||||
"type": "video_url",
|
||||
"video_url": {
|
||||
"url": url
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Provider 到策略的映射
|
||||
PROVIDER_STRATEGIES = {
|
||||
"dashscope": DashScopeImageStrategy,
|
||||
"bedrock": BedrockImageStrategy,
|
||||
"anthropic": BedrockImageStrategy,
|
||||
"openai": OpenAIImageStrategy,
|
||||
"dashscope": DashScopeFormatStrategy,
|
||||
"bedrock": BedrockFormatStrategy,
|
||||
"anthropic": BedrockFormatStrategy,
|
||||
"openai": OpenAIFormatStrategy,
|
||||
}
|
||||
|
||||
|
||||
class MultimodalService:
|
||||
"""多模态文件处理服务"""
|
||||
|
||||
def __init__(self, db: Session, provider: str = "dashscope"):
|
||||
def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False):
|
||||
"""
|
||||
初始化多模态服务
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic 等)
|
||||
provider: 模型提供商(dashscope, bedrock, anthropic, openai 等)
|
||||
api_key: API 密钥(用于音频转文本)
|
||||
enable_audio_transcription: 是否启用音频转文本
|
||||
is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式)
|
||||
"""
|
||||
self.db = db
|
||||
self.provider = provider.lower()
|
||||
self.api_key = api_key
|
||||
self.enable_audio_transcription = enable_audio_transcription
|
||||
self.is_omni = is_omni
|
||||
|
||||
async def process_files(
|
||||
self,
|
||||
@@ -137,20 +295,32 @@ class MultimodalService:
|
||||
if not files:
|
||||
return []
|
||||
|
||||
# 获取对应的策略
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容格式
|
||||
if self.provider == "dashscope" and self.is_omni:
|
||||
strategy_class = OpenAIFormatStrategy
|
||||
else:
|
||||
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
|
||||
if not strategy_class:
|
||||
logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
|
||||
strategy_class = DashScopeFormatStrategy
|
||||
|
||||
strategy = strategy_class()
|
||||
|
||||
result = []
|
||||
for idx, file in enumerate(files):
|
||||
try:
|
||||
if file.type == FileType.IMAGE:
|
||||
content = await self._process_image(file)
|
||||
content = await self._process_image(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.DOCUMENT:
|
||||
content = await self._process_document(file)
|
||||
content = await self._process_document(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.AUDIO:
|
||||
content = await self._process_audio(file)
|
||||
content = await self._process_audio(file, strategy)
|
||||
result.append(content)
|
||||
elif file.type == FileType.VIDEO:
|
||||
content = await self._process_video(file)
|
||||
content = await self._process_video(file, strategy)
|
||||
result.append(content)
|
||||
else:
|
||||
logger.warning(f"不支持的文件类型: {file.type}")
|
||||
@@ -172,55 +342,29 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
async def _process_image(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理图片文件
|
||||
|
||||
Args:
|
||||
file: 图片文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 根据 provider 返回不同格式
|
||||
- Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||
- 通义千问: {"type": "image", "image": "url"}
|
||||
Dict: 根据 provider 返回不同格式的图片内容
|
||||
"""
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
logger.debug(f"处理图片: {url}, provider={self.provider}")
|
||||
|
||||
# 根据 provider 返回不同格式
|
||||
if self.provider in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock 只支持 base64 格式,需要下载并转换
|
||||
try:
|
||||
logger.info(f"开始下载并编码图片: {url}")
|
||||
base64_data, media_type = await self._download_and_encode_image(url)
|
||||
result = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": base64_data[:100] + "..." # 只记录前100个字符
|
||||
}
|
||||
}
|
||||
logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}")
|
||||
# 返回完整数据
|
||||
result["source"]["data"] = base64_data
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"下载并编码图片失败: {e}", exc_info=True)
|
||||
# 返回错误提示
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[图片加载失败: {str(e)}]"
|
||||
}
|
||||
else:
|
||||
# 通义千问等其他格式支持 URL
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_image(url)
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "image",
|
||||
"image": url
|
||||
"type": "text",
|
||||
"text": f"[图片处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def _download_and_encode_image(self, url: str) -> tuple[str, str]:
|
||||
@staticmethod
|
||||
async def _download_and_encode_image(url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载图片并转换为 base64
|
||||
|
||||
@@ -230,8 +374,6 @@ class MultimodalService:
|
||||
Returns:
|
||||
tuple: (base64_data, media_type)
|
||||
"""
|
||||
import httpx
|
||||
import base64
|
||||
from mimetypes import guess_type
|
||||
|
||||
# 下载图片
|
||||
@@ -258,15 +400,16 @@ class MultimodalService:
|
||||
|
||||
return base64_data, media_type
|
||||
|
||||
async def _process_document(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理文档文件(PDF、Word 等)
|
||||
|
||||
Args:
|
||||
file: 文档文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: text 格式的内容(包含提取的文本)
|
||||
Dict: 根据 provider 返回不同格式的文档内容
|
||||
"""
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
# 远程文档暂不支持提取
|
||||
@@ -277,48 +420,68 @@ class MultimodalService:
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
text = await self._extract_document_text(file.upload_file_id)
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file.upload_file_id
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
|
||||
file_name = generic_file.file_name if generic_file else "unknown"
|
||||
file_name = file_metadata.file_name if file_metadata else "unknown"
|
||||
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||
}
|
||||
# 使用策略格式化文档
|
||||
return await strategy.format_document(file_name, text)
|
||||
|
||||
async def _process_audio(self, file: FileInput) -> Dict[str, Any]:
|
||||
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理音频文件
|
||||
|
||||
Args:
|
||||
file: 音频文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 音频内容(暂时返回占位符)
|
||||
Dict: 根据 provider 返回不同格式的音频内容
|
||||
"""
|
||||
# TODO: 实现音频转文字功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[音频文件,暂不支持处理]"
|
||||
}
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
|
||||
async def _process_video(self, file: FileInput) -> Dict[str, Any]:
|
||||
# 如果启用音频转文本且有 API Key
|
||||
transcription = None
|
||||
if self.enable_audio_transcription and self.api_key:
|
||||
logger.info(f"开始音频转文本: {url}")
|
||||
if self.provider == "dashscope":
|
||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key)
|
||||
elif self.provider == "openai":
|
||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key)
|
||||
else:
|
||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||
|
||||
return await strategy.format_audio(file.file_type, url, transcription)
|
||||
except Exception as e:
|
||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[音频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
|
||||
"""
|
||||
处理视频文件
|
||||
|
||||
Args:
|
||||
file: 视频文件输入
|
||||
strategy: 格式化策略
|
||||
|
||||
Returns:
|
||||
Dict: 视频内容(暂时返回占位符)
|
||||
Dict: 根据 provider 返回不同格式的视频内容
|
||||
"""
|
||||
# TODO: 实现视频处理功能
|
||||
return {
|
||||
"type": "text",
|
||||
"text": "[视频文件,暂不支持处理]"
|
||||
}
|
||||
try:
|
||||
url = await self.get_file_url(file)
|
||||
return await strategy.format_video(url)
|
||||
except Exception as e:
|
||||
logger.error(f"处理视频失败: {e}", exc_info=True)
|
||||
return {
|
||||
"type": "text",
|
||||
"text": f"[视频处理失败: {str(e)}]"
|
||||
}
|
||||
|
||||
async def get_file_url(self, file: FileInput) -> str:
|
||||
"""
|
||||
@@ -336,26 +499,22 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return file.url
|
||||
else:
|
||||
# 本地文件,通过 file_storage 系统获取永久访问 URL
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.core.config import settings
|
||||
|
||||
file_id = file.upload_file_id
|
||||
print("="*50)
|
||||
print("file_id",file_id)
|
||||
|
||||
|
||||
# 查询 FileMetadata
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
|
||||
# 返回永久URL
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
@@ -370,58 +529,79 @@ class MultimodalService:
|
||||
Returns:
|
||||
str: 提取的文本内容
|
||||
"""
|
||||
generic_file = self.db.query(GenericFile).filter(
|
||||
GenericFile.id == file_id,
|
||||
GenericFile.status == "active"
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file_id,
|
||||
FileMetadata.status == "completed"
|
||||
).first()
|
||||
|
||||
if not generic_file:
|
||||
if not file_metadata:
|
||||
raise BusinessException(
|
||||
f"文件不存在或已删除: {file_id}",
|
||||
BizCode.NOT_FOUND
|
||||
)
|
||||
|
||||
# TODO: 根据文件类型提取文本
|
||||
# - PDF: 使用 PyPDF2 或 pdfplumber
|
||||
# - Word: 使用 python-docx
|
||||
# - TXT/MD: 直接读取
|
||||
|
||||
file_ext = generic_file.file_ext.lower()
|
||||
file_ext = file_metadata.file_ext.lower()
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file_url = f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
if file_ext in ['.txt', '.md', '.markdown']:
|
||||
return await self._read_text_file(generic_file.storage_path)
|
||||
return await self._read_text_file(file_url)
|
||||
elif file_ext == '.pdf':
|
||||
return await self._extract_pdf_text(generic_file.storage_path)
|
||||
return await self._extract_pdf_text(file_url)
|
||||
elif file_ext in ['.doc', '.docx']:
|
||||
return await self._extract_word_text(generic_file.storage_path)
|
||||
return await self._extract_word_text(file_url)
|
||||
else:
|
||||
return f"[不支持的文档格式: {file_ext}]"
|
||||
|
||||
async def _read_text_file(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _read_text_file(file_url: str) -> str:
|
||||
"""读取纯文本文件"""
|
||||
try:
|
||||
with open(storage_path, 'r', encoding='utf-8') as f:
|
||||
return f.read()
|
||||
# 下载文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
return response.text
|
||||
except Exception as e:
|
||||
logger.error(f"读取文本文件失败: {e}")
|
||||
return f"[文件读取失败: {str(e)}]"
|
||||
|
||||
async def _extract_pdf_text(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _extract_pdf_text(file_url: str) -> str:
|
||||
"""提取 PDF 文本"""
|
||||
try:
|
||||
# TODO: 实现 PDF 文本提取
|
||||
# import PyPDF2 或 pdfplumber
|
||||
return "[PDF 文本提取功能待实现]"
|
||||
# 下载 PDF 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
pdf_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 PDF
|
||||
text_parts = []
|
||||
pdf_file = io.BytesIO(pdf_data)
|
||||
pdf_reader = PyPDF2.PdfReader(pdf_file)
|
||||
for page in pdf_reader.pages:
|
||||
text_parts.append(page.extract_text())
|
||||
return '\n'.join(text_parts)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 PDF 文本失败: {e}")
|
||||
return f"[PDF 提取失败: {str(e)}]"
|
||||
|
||||
async def _extract_word_text(self, storage_path: str) -> str:
|
||||
@staticmethod
|
||||
async def _extract_word_text(file_url: str) -> str:
|
||||
"""提取 Word 文档文本"""
|
||||
try:
|
||||
# TODO: 实现 Word 文本提取
|
||||
# import docx
|
||||
return "[Word 文本提取功能待实现]"
|
||||
# 下载 Word 文件
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(file_url)
|
||||
response.raise_for_status()
|
||||
word_data = response.content
|
||||
|
||||
# 使用 BytesIO 读取 Word 文档
|
||||
word_file = io.BytesIO(word_data)
|
||||
doc = Document(word_file)
|
||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||
return '\n'.join(text_parts)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 Word 文本失败: {e}")
|
||||
return f"[Word 提取失败: {str(e)}]"
|
||||
|
||||
@@ -101,34 +101,141 @@ async def run_pilot_extraction(
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...")
|
||||
|
||||
# ========== 步骤 2.1: 语义剪枝 ==========
|
||||
pruned_dialogs = [dialog]
|
||||
deleted_messages = [] # 记录被删除的消息
|
||||
pruning_stats = None # 保存剪枝统计信息,用于最终汇总
|
||||
|
||||
if memory_config.pruning_enabled:
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import (
|
||||
SemanticPruner,
|
||||
)
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
|
||||
# 构建剪枝配置
|
||||
pruning_config_dict = {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
"llm_model_id": str(memory_config.llm_model_id),
|
||||
}
|
||||
config = PruningConfig(**pruning_config_dict)
|
||||
|
||||
logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}")
|
||||
|
||||
# 记录剪枝前的消息(用于对比)
|
||||
original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs]
|
||||
original_msg_count = len(original_messages)
|
||||
|
||||
# 执行剪枝
|
||||
pruner = SemanticPruner(config=config, llm_client=llm_client)
|
||||
pruned_dialogs = await pruner.prune_dataset([dialog])
|
||||
|
||||
# 计算剪枝结果并找出被删除的消息
|
||||
if pruned_dialogs and pruned_dialogs[0].context:
|
||||
remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs]
|
||||
remaining_msg_count = len(remaining_messages)
|
||||
deleted_msg_count = original_msg_count - remaining_msg_count
|
||||
|
||||
# 找出被删除的消息(基于索引精确匹配)
|
||||
# 为剩余消息创建带索引的列表,用于精确追踪
|
||||
remaining_with_index = []
|
||||
remaining_idx = 0
|
||||
for orig_idx, orig_msg in enumerate(original_messages):
|
||||
if remaining_idx < len(remaining_messages) and \
|
||||
orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \
|
||||
orig_msg["content"] == remaining_messages[remaining_idx]["content"]:
|
||||
remaining_with_index.append(orig_idx)
|
||||
remaining_idx += 1
|
||||
|
||||
# 找出未在保留列表中的消息索引
|
||||
deleted_messages = [
|
||||
{"index": idx, "role": msg["role"], "content": msg["content"]}
|
||||
for idx, msg in enumerate(original_messages)
|
||||
if idx not in remaining_with_index
|
||||
]
|
||||
|
||||
# 保存剪枝统计信息(用于最终汇总,只保留deleted_count)
|
||||
pruning_stats = {
|
||||
"enabled": True,
|
||||
"scene": config.pruning_scene,
|
||||
"threshold": config.pruning_threshold,
|
||||
"deleted_count": deleted_msg_count,
|
||||
}
|
||||
|
||||
# 输出剪枝结果(显示删除的消息详情)
|
||||
pruning_result = {
|
||||
"type": "pruning",
|
||||
"deleted_messages": deleted_messages,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> "
|
||||
f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)"
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result)
|
||||
else:
|
||||
logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话")
|
||||
pruned_dialogs = [dialog]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True)
|
||||
pruned_dialogs = [dialog]
|
||||
if progress_callback:
|
||||
error_result = {
|
||||
"type": "pruning",
|
||||
"error": str(e),
|
||||
"fallback": "使用原始对话"
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result)
|
||||
else:
|
||||
logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过")
|
||||
pruning_stats = {
|
||||
"enabled": False,
|
||||
}
|
||||
|
||||
# ========== 步骤 2.2: 语义分块 ==========
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
data=pruned_dialogs,
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0
|
||||
logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
if hasattr(dlg, 'chunks') and dlg.chunks:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"type": "chunking",
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
# 构建预处理完成总结(包含剪枝统计)
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
# 添加剪枝统计信息
|
||||
if pruning_stats:
|
||||
preprocessing_summary["pruning"] = pruning_stats
|
||||
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
|
||||
@@ -184,7 +184,8 @@ class PromptOptimizerService:
|
||||
model_name=api_config.model_name,
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base
|
||||
base_url=api_config.api_base,
|
||||
is_omni=api_config.is_omni
|
||||
), type=ModelType(model_config.type))
|
||||
try:
|
||||
prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
|
||||
|
||||
@@ -21,63 +21,64 @@ from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class SharedChatService:
|
||||
"""基于分享链接的聊天服务"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_service = ConversationService(db)
|
||||
self.share_service = ReleaseShareService(db)
|
||||
|
||||
def _get_release_by_share_token(
|
||||
self,
|
||||
share_token: str,
|
||||
password: Optional[str] = None
|
||||
|
||||
def get_release_by_share_token(
|
||||
self,
|
||||
share_token: str,
|
||||
password: Optional[str] = None
|
||||
) -> tuple[ReleaseShare, AppRelease]:
|
||||
"""通过 share_token 获取发布版本"""
|
||||
# 获取分享配置
|
||||
share = self.share_service.repo.get_by_share_token(share_token)
|
||||
if not share:
|
||||
raise ResourceNotFoundException("分享链接", share_token)
|
||||
|
||||
|
||||
# 验证分享是否启用
|
||||
if not share.is_enabled:
|
||||
raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED)
|
||||
|
||||
|
||||
# 验证密码
|
||||
if share.require_password:
|
||||
if not password:
|
||||
raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED)
|
||||
|
||||
|
||||
if not self.share_service.verify_password(share_token, password):
|
||||
raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD)
|
||||
|
||||
|
||||
# 获取发布版本
|
||||
release = self.db.get(AppRelease, share.release_id)
|
||||
if not release:
|
||||
raise ResourceNotFoundException("发布版本", str(share.release_id))
|
||||
|
||||
|
||||
# 更新访问统计
|
||||
try:
|
||||
self.share_service.repo.increment_view_count(share.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"更新访问统计失败: {str(e)}")
|
||||
|
||||
|
||||
return share, release
|
||||
|
||||
|
||||
def create_or_get_conversation(
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None
|
||||
) -> Conversation:
|
||||
"""创建或获取会话"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 如果提供了 conversation_id,尝试获取现有会话
|
||||
if conversation_id:
|
||||
try:
|
||||
@@ -85,18 +86,18 @@ class SharedChatService:
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=release.app.workspace_id
|
||||
)
|
||||
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != release.app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
|
||||
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
"会话不存在,将创建新会话",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
|
||||
# 创建新会话(使用发布版本的配置)
|
||||
conversation = self.conversation_service.create_conversation(
|
||||
app_id=release.app_id,
|
||||
@@ -105,7 +106,7 @@ class SharedChatService:
|
||||
is_draft=False, # 分享链接使用发布版本
|
||||
config_snapshot=release.config
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
"为分享链接创建新会话",
|
||||
extra={
|
||||
@@ -114,25 +115,25 @@ class SharedChatService:
|
||||
"release_id": str(release.id)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
actual_config_id = None
|
||||
config_id=actual_config_id
|
||||
config_id = actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
@@ -140,32 +141,30 @@ class SharedChatService:
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取 Agent 配置
|
||||
config = release.config or {}
|
||||
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 获取模型配置
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
|
||||
# 获取 API Key
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
@@ -184,7 +183,7 @@ class SharedChatService:
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
@@ -192,7 +191,7 @@ class SharedChatService:
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
|
||||
if variables:
|
||||
@@ -202,31 +201,31 @@ class SharedChatService:
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag=False
|
||||
memory_flag = False
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
memory_flag=True
|
||||
memory_flag = True
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools=config.get("tools")
|
||||
web_tools = config.get("tools")
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled",False)
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search:
|
||||
if web_search_enable:
|
||||
search_tool = create_web_search_tool({})
|
||||
@@ -238,26 +237,27 @@ class SharedChatService:
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
|
||||
)
|
||||
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config={"enabled":True,'max_history':10}
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation.id,
|
||||
@@ -267,7 +267,7 @@ class SharedChatService:
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
|
||||
# 调用 Agent
|
||||
result = await agent.chat(
|
||||
message=message,
|
||||
@@ -279,7 +279,7 @@ class SharedChatService:
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.save_conversation_messages(
|
||||
conversation_id=conversation.id,
|
||||
@@ -298,7 +298,7 @@ class SharedChatService:
|
||||
# role="user",
|
||||
# content=message
|
||||
# )
|
||||
|
||||
|
||||
# self.conversation_service.add_message(
|
||||
# conversation_id=conversation.id,
|
||||
# role="assistant",
|
||||
@@ -308,12 +308,11 @@ class SharedChatService:
|
||||
# "usage": result.get("usage", {})
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
"message": result["content"],
|
||||
@@ -324,19 +323,19 @@ class SharedChatService:
|
||||
}),
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
@@ -345,36 +344,35 @@ class SharedChatService:
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
# 兼容新旧字段名:使用 memory_config_id
|
||||
memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10}
|
||||
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取 Agent 配置
|
||||
config = release.config or {}
|
||||
agent_config_data = config.get("agent_config", {})
|
||||
|
||||
|
||||
# 获取模型配置ID
|
||||
model_config_id = release.default_model_config_id
|
||||
if not model_config_id:
|
||||
raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 获取模型配置
|
||||
from app.models import ModelConfig
|
||||
model_config = self.db.get(ModelConfig, model_config_id)
|
||||
if not model_config:
|
||||
raise ResourceNotFoundException("模型配置", str(model_config_id))
|
||||
|
||||
|
||||
# 获取 API Key
|
||||
# stmt = (
|
||||
# select(ModelApiKey).join(
|
||||
@@ -393,7 +391,7 @@ class SharedChatService:
|
||||
api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id)
|
||||
if not api_key_obj:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
@@ -401,7 +399,7 @@ class SharedChatService:
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
|
||||
# 处理系统提示词(支持变量替换)
|
||||
system_prompt = config.get("system_prompt", "你是一个专业的AI助手")
|
||||
if variables:
|
||||
@@ -411,21 +409,21 @@ class SharedChatService:
|
||||
variables
|
||||
)
|
||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||
|
||||
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.get("knowledge_retrieval")
|
||||
if knowledge_retrieval:
|
||||
knowledge_bases = knowledge_retrieval.get("knowledge_bases", [])
|
||||
kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id)
|
||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id)
|
||||
tools.append(kb_tool)
|
||||
|
||||
|
||||
# 添加长期记忆工具
|
||||
memory_flag=False
|
||||
memory_flag = False
|
||||
if memory:
|
||||
memory_config = config.get("memory", {})
|
||||
if memory_config.get("enabled") and user_id:
|
||||
@@ -450,20 +448,21 @@ class SharedChatService:
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.get("model_parameters", {})
|
||||
|
||||
|
||||
# 创建 LangChain Agent
|
||||
agent = LangChainAgent(
|
||||
model_name=api_key_obj.model_name,
|
||||
api_key=api_key_obj.api_key,
|
||||
provider=api_key_obj.provider,
|
||||
api_base=api_key_obj.api_base,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
temperature=model_parameters.get("temperature", 0.7),
|
||||
max_tokens=model_parameters.get("max_tokens", 2000),
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
streaming=True
|
||||
)
|
||||
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
@@ -476,22 +475,22 @@ class SharedChatService:
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# 流式调用 Agent
|
||||
full_content = ""
|
||||
total_tokens = 0
|
||||
async for chunk in agent.chat_stream(
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
@@ -499,16 +498,16 @@ class SharedChatService:
|
||||
full_content += chunk
|
||||
# 发送消息块事件
|
||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
@@ -524,7 +523,7 @@ class SharedChatService:
|
||||
# 发送结束事件
|
||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
|
||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
logger.info(
|
||||
"流式聊天完成",
|
||||
extra={
|
||||
@@ -533,7 +532,7 @@ class SharedChatService:
|
||||
"message_length": len(full_content)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("流式聊天被中断")
|
||||
@@ -542,39 +541,39 @@ class SharedChatService:
|
||||
logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
|
||||
# 发送错误事件
|
||||
yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
def get_conversation_messages(
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: uuid.UUID,
|
||||
password: Optional[str] = None
|
||||
self,
|
||||
share_token: str,
|
||||
conversation_id: uuid.UUID,
|
||||
password: Optional[str] = None
|
||||
) -> Conversation:
|
||||
"""获取会话消息"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取会话
|
||||
conversation = self.conversation_service.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=release.app.workspace_id
|
||||
)
|
||||
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != release.app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
|
||||
|
||||
return conversation
|
||||
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
share_token: str,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
self,
|
||||
share_token: str,
|
||||
user_id: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""列出会话"""
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
conversations, total = self.conversation_service.list_conversations(
|
||||
app_id=release.app_id,
|
||||
workspace_id=release.app.workspace_id,
|
||||
@@ -583,19 +582,19 @@ class SharedChatService:
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
|
||||
return conversations, total
|
||||
|
||||
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -603,18 +602,16 @@ class SharedChatService:
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
@@ -622,19 +619,19 @@ class SharedChatService:
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
|
||||
# 获取多 Agent 配置
|
||||
multi_agent_config = self.db.query(MultiAgentConfig).filter(
|
||||
MultiAgentConfig.app_id == release.app_id,
|
||||
MultiAgentConfig.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
|
||||
if not multi_agent_config:
|
||||
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 构建多 Agent 运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=message,
|
||||
conversation_id=conversation.id,
|
||||
@@ -644,23 +641,23 @@ class SharedChatService:
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
|
||||
# 使用多 Agent 服务执行
|
||||
multi_agent_service = MultiAgentService(self.db)
|
||||
result = await multi_agent_service.run(
|
||||
app_id=release.app_id,
|
||||
request=multi_agent_request
|
||||
)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
@@ -672,8 +669,6 @@ class SharedChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
return {
|
||||
"conversation_id": conversation.id,
|
||||
"message": result.get("message", ""),
|
||||
@@ -684,34 +679,33 @@ class SharedChatService:
|
||||
},
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
self,
|
||||
share_token: str,
|
||||
message: str,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
password: Optional[str] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""多 Agent 聊天(流式)"""
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id=None
|
||||
config_id=actual_config_id
|
||||
|
||||
actual_config_id = None
|
||||
config_id = actual_config_id
|
||||
|
||||
if variables is None:
|
||||
variables = {}
|
||||
|
||||
|
||||
try:
|
||||
# 获取发布版本和配置
|
||||
share, release = self._get_release_by_share_token(share_token, password)
|
||||
|
||||
share, release = self.get_release_by_share_token(share_token, password)
|
||||
|
||||
# 获取或创建会话
|
||||
conversation = self.create_or_get_conversation(
|
||||
share_token=share_token,
|
||||
@@ -719,28 +713,28 @@ class SharedChatService:
|
||||
user_id=user_id,
|
||||
password=password
|
||||
)
|
||||
|
||||
|
||||
# 获取多 Agent 配置
|
||||
multi_agent_config = self.db.query(MultiAgentConfig).filter(
|
||||
MultiAgentConfig.app_id == release.app_id,
|
||||
MultiAgentConfig.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
|
||||
if not multi_agent_config:
|
||||
raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
# 获取 storage_type 和 user_rag_memory_id
|
||||
workspace_id = release.app.workspace_id
|
||||
storage_type = 'neo4j' # 默认值
|
||||
user_rag_memory_id = ''
|
||||
|
||||
|
||||
try:
|
||||
# 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享)
|
||||
from app.models import Workspace
|
||||
workspace = self.db.get(Workspace, workspace_id)
|
||||
if workspace and workspace.storage_type:
|
||||
storage_type = workspace.storage_type
|
||||
|
||||
|
||||
# 获取 USER_RAG_MERORY 知识库 ID
|
||||
knowledge = knowledge_repository.get_knowledge_by_name(
|
||||
db=self.db,
|
||||
@@ -751,13 +745,13 @@ class SharedChatService:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}")
|
||||
|
||||
|
||||
# 发送开始事件
|
||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
# 构建多 Agent 运行请求
|
||||
from app.schemas.multi_agent_schema import MultiAgentRunRequest
|
||||
|
||||
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=message,
|
||||
conversation_id=conversation.id,
|
||||
@@ -767,20 +761,20 @@ class SharedChatService:
|
||||
web_search=web_search,
|
||||
memory=memory
|
||||
)
|
||||
|
||||
|
||||
# 使用多 Agent 服务流式执行
|
||||
multi_agent_service = MultiAgentService(self.db)
|
||||
full_content = ""
|
||||
|
||||
|
||||
async for event in multi_agent_service.run_stream(
|
||||
app_id=release.app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
app_id=release.app_id,
|
||||
request=multi_agent_request,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
# 直接转发事件
|
||||
yield event
|
||||
|
||||
|
||||
# 尝试提取内容(用于保存)
|
||||
if "data:" in event:
|
||||
try:
|
||||
@@ -790,16 +784,16 @@ class SharedChatService:
|
||||
full_content += data["content"]
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 保存消息
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role="assistant",
|
||||
@@ -808,7 +802,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
"多 Agent 流式聊天完成",
|
||||
extra={
|
||||
@@ -818,7 +812,6 @@ class SharedChatService:
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
except (GeneratorExit, asyncio.CancelledError):
|
||||
# 生成器被关闭或任务被取消,正常退出
|
||||
logger.debug("多 Agent 流式聊天被中断")
|
||||
|
||||
@@ -121,7 +121,7 @@ class SkillService:
|
||||
if skill and skill.is_active:
|
||||
# 加载技能关联的工具
|
||||
for tool_config in skill.tools:
|
||||
tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||
if tool:
|
||||
langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
@@ -209,7 +209,7 @@ class ToolService:
|
||||
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool = self._get_tool_instance(tool_id, tenant_id)
|
||||
tool = self.get_tool_instance(tool_id, tenant_id)
|
||||
if not tool:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具不存在: {tool_id}",
|
||||
@@ -335,7 +335,7 @@ class ToolService:
|
||||
return []
|
||||
|
||||
# 获取工具实例
|
||||
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
||||
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||||
if not tool_instance:
|
||||
return []
|
||||
|
||||
@@ -792,7 +792,7 @@ class ToolService:
|
||||
"""获取工具配置"""
|
||||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||
|
||||
def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||
"""获取工具实例"""
|
||||
if tool_id in self._tool_cache:
|
||||
return self._tool_cache[tool_id]
|
||||
@@ -1416,7 +1416,7 @@ class ToolService:
|
||||
"""测试内置工具连接"""
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
||||
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||||
if not tool_instance:
|
||||
return {"success": False, "message": "无法创建工具实例"}
|
||||
|
||||
|
||||
@@ -10,6 +10,9 @@ from collections import Counter
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
@@ -23,8 +26,6 @@ from app.services.memory_base_service import MemoryBaseService, MemoryTransServi
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -1035,9 +1036,10 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan
|
||||
"growth_trajectory": str # 成长轨迹
|
||||
}
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
from app.core.language_utils import validate_language
|
||||
import re
|
||||
|
||||
from app.core.language_utils import validate_language
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
|
||||
# 验证语言参数
|
||||
language = validate_language(language)
|
||||
@@ -1161,13 +1163,35 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
"one_sentence": str
|
||||
}
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||
from app.core.language_utils import validate_language
|
||||
import re
|
||||
|
||||
from app.core.language_utils import validate_language
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
# 验证语言参数
|
||||
language = validate_language(language)
|
||||
|
||||
# 获取用户的 other_name 字段
|
||||
user_display_name = "该用户" if language == "zh" else "the user"
|
||||
if end_user_id:
|
||||
try:
|
||||
# 获取数据库会话并查询用户信息
|
||||
db = next(get_db())
|
||||
try:
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
if end_user and end_user.other_name:
|
||||
user_display_name = end_user.other_name
|
||||
logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}")
|
||||
else:
|
||||
logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}")
|
||||
|
||||
# 创建 UserSummaryHelper 实例
|
||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
|
||||
|
||||
@@ -1184,7 +1208,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st
|
||||
user_id=user_summary_tool.user_id,
|
||||
entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)",
|
||||
statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)",
|
||||
language=language
|
||||
language=language,
|
||||
user_display_name=user_display_name
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -1435,7 +1460,7 @@ async def analytics_memory_types(
|
||||
short_term_count = 0
|
||||
if end_user_id:
|
||||
try:
|
||||
short_term_service = ShortService(end_user_id)
|
||||
short_term_service = ShortService(end_user_id, db)
|
||||
short_term_data = short_term_service.get_short_databasets()
|
||||
# 统计 short_term 数组的长度
|
||||
if short_term_data:
|
||||
@@ -1449,8 +1474,10 @@ async def analytics_memory_types(
|
||||
forgetting_threshold = 0.3 # 默认值
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
load_actr_config_from_db,
|
||||
)
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db
|
||||
|
||||
# 获取用户关联的 config_id
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
|
||||
@@ -13,7 +13,10 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
from app.db import get_db
|
||||
from app.models import App
|
||||
from app.models.workflow_model import WorkflowConfig, WorkflowExecution
|
||||
@@ -22,7 +25,7 @@ from app.repositories.workflow_repository import (
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
@@ -444,6 +447,94 @@ class WorkflowService:
|
||||
"success_rate": completed / total if total > 0 else 0
|
||||
}
|
||||
|
||||
async def _handle_file_input(self, files: list[FileInput]):
|
||||
if not files:
|
||||
return []
|
||||
|
||||
files_struct = []
|
||||
for file in files:
|
||||
files_struct.append(
|
||||
FileObject(
|
||||
type=file.type,
|
||||
url=await self.multimodal_service.get_file_url(file),
|
||||
transfer_method=file.transfer_method,
|
||||
file_id=str(file.upload_file_id),
|
||||
origin_file_type=file.file_type,
|
||||
is_file=True
|
||||
).model_dump()
|
||||
)
|
||||
return files_struct
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
"""
|
||||
Map internal workflow events to public-facing event formats.
|
||||
|
||||
Purpose:
|
||||
- Hide internal execution details
|
||||
- Expose a stable and simplified public event schema
|
||||
- Filter out non-public events
|
||||
- Maintain backward compatibility when possible
|
||||
|
||||
Args:
|
||||
event (dict): Internal event object, e.g.:
|
||||
{
|
||||
"event": "workflow_start",
|
||||
"data": {...}
|
||||
}
|
||||
|
||||
Returns:
|
||||
dict | None:
|
||||
- Returns the mapped public event
|
||||
- Returns None if the event should not be exposed
|
||||
"""
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", "")),
|
||||
"error": payload.get("error", "")
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
Unified event emission entry.
|
||||
|
||||
Args:
|
||||
public (bool):
|
||||
- True -> Emit mapped public event
|
||||
- False -> Emit raw internal event
|
||||
|
||||
internal_event (dict):
|
||||
The original internal event object
|
||||
|
||||
Returns:
|
||||
dict | None:
|
||||
- The mapped event
|
||||
- Or None if the event is filtered out
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -478,10 +569,11 @@ class WorkflowService:
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
input_data = {
|
||||
"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
@@ -505,22 +597,8 @@ class WorkflowService:
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
files = await self._handle_file_input(payload.files)
|
||||
input_data["files"] = files
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
@@ -580,6 +658,7 @@ class WorkflowService:
|
||||
# "variables": result.get("variables"),
|
||||
# "messages": result.get("messages"),
|
||||
"output": result.get("output"), # 最终输出(字符串)
|
||||
"message": result.get("output"), # 最终输出(字符串)
|
||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||
"error_message": result.get("error"),
|
||||
@@ -599,41 +678,6 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
decide
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -668,10 +712,11 @@ class WorkflowService:
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
input_data = {
|
||||
"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
@@ -696,16 +741,7 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
try:
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
files = await self._handle_file_input(payload.files)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
@@ -720,7 +756,6 @@ class WorkflowService:
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
@@ -778,36 +813,13 @@ class WorkflowService:
|
||||
}
|
||||
}
|
||||
|
||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""清理事件数据,移除不可序列化的对象
|
||||
|
||||
Args:
|
||||
event: 原始事件数据
|
||||
|
||||
Returns:
|
||||
可序列化的事件数据
|
||||
"""
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
def clean_value(value):
|
||||
"""递归清理值"""
|
||||
if isinstance(value, BaseMessage):
|
||||
# 将 Message 对象转换为字典
|
||||
return {
|
||||
"type": value.__class__.__name__,
|
||||
"content": value.content,
|
||||
}
|
||||
elif isinstance(value, dict):
|
||||
return {k: clean_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [clean_value(item) for item in value]
|
||||
elif isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
else:
|
||||
# 其他不可序列化的对象转换为字符串
|
||||
return str(value)
|
||||
|
||||
return clean_value(event)
|
||||
@staticmethod
|
||||
def get_start_node_variables(config: dict) -> list:
|
||||
nodes = config.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("type") == NodeType.START:
|
||||
return node.get("config", {}).get("variables", [])
|
||||
raise BusinessException("workflow config error - start node not found")
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
545
api/app/tasks.py
545
api/app/tasks.py
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
@@ -38,7 +38,7 @@ from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
@@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
@@ -256,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
|
||||
return result
|
||||
|
||||
try:
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
@@ -271,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
future.result() # Blocks until the task completes
|
||||
except Exception as e:
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
|
||||
@@ -297,8 +302,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
build knowledge graph
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_documents = None
|
||||
@@ -932,24 +938,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||
storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -1049,19 +1049,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
@@ -1069,11 +1065,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
user_rag_memory_id, language)
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -1304,6 +1295,203 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
"workspace_id": workspace_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
}
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_all_workspaces_memory_task",
|
||||
bind=True,
|
||||
ignore_result=False,
|
||||
max_retries=3,
|
||||
acks_late=True,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:遍历所有工作空间,统计并写入记忆增量
|
||||
|
||||
此任务会:
|
||||
1. 查询所有活跃的工作空间
|
||||
2. 对每个工作空间统计记忆总量
|
||||
3. 将统计结果写入 memory_increments 表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活跃的工作空间
|
||||
workspaces = db.query(Workspace).filter(
|
||||
Workspace.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not workspaces:
|
||||
api_logger.warning("没有找到活跃的工作空间")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到活跃的工作空间",
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
api_logger.info(f"开始统计 {len(workspaces)} 个工作空间的记忆增量")
|
||||
all_workspace_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间: {workspace.name} (ID: {workspace_id})")
|
||||
|
||||
try:
|
||||
# 1. 查询当前workspace下的所有app(仅未删除的)
|
||||
apps = db.query(App).filter(
|
||||
App.workspace_id == workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).all()
|
||||
|
||||
if not apps:
|
||||
# 如果没有app,总量为0
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=0
|
||||
)
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
api_logger.info(f"工作空间 {workspace.name} 没有应用,记录总量为0")
|
||||
continue
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
total_num = 0
|
||||
end_user_details = []
|
||||
|
||||
for (end_user_id,) in end_users:
|
||||
try:
|
||||
# 调用 search_all 接口查询该宿主的总量
|
||||
result = await search_all(str(end_user_id))
|
||||
user_total = result.get("total", 0)
|
||||
total_num += user_total
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": user_total
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
api_logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 4. 写入数据库
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
total_num=total_num
|
||||
)
|
||||
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "SUCCESS",
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace.name} 统计完成: 总量={total_num}, 用户数={len(end_users)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db.rollback() # 回滚失败的事务,允许继续处理下一个工作空间
|
||||
api_logger.error(f"处理工作空间 {workspace.name} (ID: {workspace_id}) 失败: {str(e)}")
|
||||
all_workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"workspace_name": workspace.name,
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
})
|
||||
|
||||
total_memory = sum(r.get("total_num", 0) for r in all_workspace_results)
|
||||
success_count = sum(1 for r in all_workspace_results if r.get("status") == "SUCCESS")
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {success_count}/{len(workspaces)} 个工作空间,总记忆量: {total_memory}",
|
||||
"workspace_count": len(workspaces),
|
||||
"success_count": success_count,
|
||||
"total_memory": total_memory,
|
||||
"workspace_results": all_workspace_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆增量统计任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"workspace_results": []
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -1924,4 +2112,307 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
|
||||
# "config_id": config_id,
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "task_id": self.request.id
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 隐性记忆和情绪数据更新定时任务
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.update_implicit_emotions_storage",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=7200, # 2小时硬超时
|
||||
soft_time_limit=6900, # 1小时55分钟软超时
|
||||
)
|
||||
def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"""定时任务:更新所有用户的隐性记忆画像和情绪建议数据
|
||||
|
||||
遍历数据库中所有已存在数据的用户,为每个用户重新生成隐性记忆画像和情绪建议。
|
||||
实现错误隔离,单个用户失败不影响其他用户的处理。
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典,包括:
|
||||
- status: 任务状态 (SUCCESS/FAILURE)
|
||||
- message: 执行消息
|
||||
- total_users: 总用户数
|
||||
- successful_implicit: 成功更新隐性记忆的用户数
|
||||
- successful_emotion: 成功更新情绪建议的用户数
|
||||
- failed: 失败的用户数
|
||||
- user_results: 每个用户的详细结果
|
||||
- elapsed_time: 执行耗时(秒)
|
||||
- task_id: 任务ID
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.implicit_emotions_storage_repository import ImplicitEmotionsStorageRepository
|
||||
from app.models.implicit_emotions_storage_model import ImplicitEmotionsStorage
|
||||
from sqlalchemy import select, func
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行隐性记忆和情绪数据更新定时任务")
|
||||
|
||||
total_users = 0
|
||||
successful_implicit = 0
|
||||
successful_emotion = 0
|
||||
failed = 0
|
||||
user_results = []
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有已存储数据的用户ID(分批次处理)
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
# 先统计总数用于日志
|
||||
from sqlalchemy import func
|
||||
total_users = db.execute(
|
||||
select(func.count()).select_from(ImplicitEmotionsStorage)
|
||||
).scalar() or 0
|
||||
logger.info(f"找到 {total_users} 个需要更新的用户")
|
||||
|
||||
# 遍历每个用户并更新数据(分批次,避免一次性加载所有ID)
|
||||
for end_user_id in repo.get_all_user_ids(batch_size=100):
|
||||
logger.info(f"开始处理用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# 更新隐性记忆画像
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功更新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆更新失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"用户 {end_user_id} {error_msg}")
|
||||
|
||||
# 更新情绪建议
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功更新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议更新失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"用户 {end_user_id} {error_msg}")
|
||||
|
||||
# 统计结果
|
||||
if implicit_success:
|
||||
successful_implicit += 1
|
||||
if emotion_success:
|
||||
successful_emotion += 1
|
||||
if not implicit_success and not emotion_success:
|
||||
failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
|
||||
# 记录用户处理结果
|
||||
user_result = {
|
||||
"end_user_id": end_user_id,
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
"elapsed_time": user_elapsed
|
||||
}
|
||||
user_results.append(user_result)
|
||||
|
||||
logger.info(
|
||||
f"用户 {end_user_id} 处理完成: "
|
||||
f"隐性记忆={'成功' if implicit_success else '失败'}, "
|
||||
f"情绪建议={'成功' if emotion_success else '失败'}, "
|
||||
f"耗时={user_elapsed:.2f}秒"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 单个用户失败不影响其他用户(错误隔离)
|
||||
failed += 1
|
||||
user_elapsed = time.time() - user_start_time
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
"elapsed_time": user_elapsed
|
||||
}
|
||||
user_results.append(error_info)
|
||||
logger.error(f"处理用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
# ---- 处理增量用户(当天新增、尚未初始化的用户)----
|
||||
new_users_initialized = 0
|
||||
new_users_failed = 0
|
||||
logger.info("开始处理当天新增的增量用户初始化")
|
||||
|
||||
for end_user_id in repo.get_new_user_ids_today(batch_size=100):
|
||||
logger.info(f"开始初始化新用户: {end_user_id}")
|
||||
user_start_time = time.time()
|
||||
implicit_success = False
|
||||
emotion_success = False
|
||||
errors = []
|
||||
|
||||
try:
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id,
|
||||
profile_data=profile_data,
|
||||
db=db
|
||||
)
|
||||
implicit_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的隐性记忆画像")
|
||||
except Exception as e:
|
||||
error_msg = f"隐性记忆初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id,
|
||||
db=db,
|
||||
language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id,
|
||||
suggestions_data=suggestions_data,
|
||||
db=db
|
||||
)
|
||||
emotion_success = True
|
||||
logger.info(f"成功初始化新用户 {end_user_id} 的情绪建议")
|
||||
except Exception as e:
|
||||
error_msg = f"情绪建议初始化失败: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(f"新用户 {end_user_id} {error_msg}")
|
||||
|
||||
if implicit_success or emotion_success:
|
||||
new_users_initialized += 1
|
||||
else:
|
||||
new_users_failed += 1
|
||||
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"implicit_success": implicit_success,
|
||||
"emotion_success": emotion_success,
|
||||
"errors": errors,
|
||||
"elapsed_time": user_elapsed
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
new_users_failed += 1
|
||||
user_elapsed = time.time() - user_start_time
|
||||
user_results.append({
|
||||
"end_user_id": end_user_id,
|
||||
"type": "init",
|
||||
"implicit_success": False,
|
||||
"emotion_success": False,
|
||||
"errors": [str(e)],
|
||||
"elapsed_time": user_elapsed
|
||||
})
|
||||
logger.error(f"初始化新用户 {end_user_id} 时出错: {str(e)}")
|
||||
|
||||
logger.info(
|
||||
f"增量用户初始化完成: 成功={new_users_initialized}, 失败={new_users_failed}"
|
||||
)
|
||||
# ---- 增量用户处理结束 ----
|
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"隐性记忆和情绪数据更新定时任务完成: "
|
||||
f"存量用户总数={total_users}, "
|
||||
f"隐性记忆成功={successful_implicit}, "
|
||||
f"情绪建议成功={successful_emotion}, "
|
||||
f"存量失败={failed}, "
|
||||
f"增量初始化成功={new_users_initialized}, "
|
||||
f"增量初始化失败={new_users_failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": (
|
||||
f"存量用户 {total_users} 个,隐性记忆 {successful_implicit} 个成功,情绪建议 {successful_emotion} 个成功;"
|
||||
f"增量新用户初始化 {new_users_initialized} 个成功,{new_users_failed} 个失败"
|
||||
),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": new_users_initialized,
|
||||
"new_users_failed": new_users_failed,
|
||||
"user_results": user_results[:50] # 只保留前50个用户的详细结果
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"隐性记忆和情绪数据更新定时任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"total_users": total_users,
|
||||
"successful_implicit": successful_implicit,
|
||||
"successful_emotion": successful_emotion,
|
||||
"failed": failed,
|
||||
"new_users_initialized": 0,
|
||||
"new_users_failed": 0,
|
||||
"user_results": user_results[:50]
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
@@ -139,7 +139,7 @@ SMTP_USER=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# 本体类型融合配置 (记得写入env_example)
|
||||
GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
|
||||
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
|
||||
|
||||
43
api/migrations/versions/6a4641cf192b_202603051440.py
Normal file
43
api/migrations/versions/6a4641cf192b_202603051440.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""202603051440
|
||||
|
||||
Revision ID: 6a4641cf192b
|
||||
Revises: b4af97639217
|
||||
Create Date: 2026-03-05 14:41:03.371557
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6a4641cf192b'
|
||||
down_revision: Union[str, None] = 'b4af97639217'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('implicit_emotions_storage',
|
||||
sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'),
|
||||
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
|
||||
sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'),
|
||||
sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'),
|
||||
sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'),
|
||||
sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('end_user_id')
|
||||
)
|
||||
op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index('idx_updated_at', table_name='implicit_emotions_storage')
|
||||
op.drop_table('implicit_emotions_storage')
|
||||
# ### end Alembic commands ###
|
||||
63
api/migrations/versions/b4af97639217_202603051033.py
Normal file
63
api/migrations/versions/b4af97639217_202603051033.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""202603051033
|
||||
|
||||
Revision ID: b4af97639217
|
||||
Revises: 4bf27c66ae63
|
||||
Create Date: 2026-03-05 10:36:06.282227
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b4af97639217'
|
||||
down_revision: Union[str, None] = '4bf27c66ae63'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
# Add columns as nullable first to avoid table locks
|
||||
op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
|
||||
op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))
|
||||
|
||||
op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
|
||||
op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))
|
||||
|
||||
op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])"))
|
||||
op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)'))
|
||||
|
||||
# Update existing rows with default values
|
||||
op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL")
|
||||
op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL")
|
||||
|
||||
op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL")
|
||||
op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL")
|
||||
|
||||
op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL")
|
||||
op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL")
|
||||
|
||||
# Now make columns NOT NULL
|
||||
op.alter_column('model_api_keys', 'capability', nullable=False)
|
||||
op.alter_column('model_api_keys', 'is_omni', nullable=False)
|
||||
|
||||
op.alter_column('model_bases', 'capability', nullable=False)
|
||||
op.alter_column('model_bases', 'is_omni', nullable=False)
|
||||
|
||||
op.alter_column('model_configs', 'capability', nullable=False)
|
||||
op.alter_column('model_configs', 'is_omni', nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('model_configs', 'is_omni')
|
||||
op.drop_column('model_configs', 'capability')
|
||||
op.drop_column('model_bases', 'is_omni')
|
||||
op.drop_column('model_bases', 'capability')
|
||||
op.drop_column('model_api_keys', 'is_omni')
|
||||
op.drop_column('model_api_keys', 'capability')
|
||||
# ### end Alembic commands ###
|
||||
2
web/.gitignore
vendored
2
web/.gitignore
vendored
@@ -23,4 +23,4 @@ dist-ssr
|
||||
*.sln
|
||||
*.sw?
|
||||
vite.config.js
|
||||
package-lock.json
|
||||
package-lock.json
|
||||
@@ -1,205 +1,195 @@
|
||||
# i18n 中英文对比报告
|
||||
# Memory Bear 前端项目 - 中英文国际化对比报告
|
||||
|
||||
## 📊 统计概览
|
||||
生成时间: 2024
|
||||
|
||||
- **中文键总数**: 1136
|
||||
- **英文键总数**: 1052
|
||||
- **中文缺失**: 27 个键
|
||||
- **英文缺失**: 111 个键
|
||||
## 📊 概览统计
|
||||
|
||||
### 文件信息
|
||||
- **中文文件**: `src/i18n/zh.ts`
|
||||
- **英文文件**: `src/i18n/en.ts`
|
||||
|
||||
### 模块统计
|
||||
| 模块名称 | 中文键数 | 英文键数 | 状态 |
|
||||
|---------|---------|---------|------|
|
||||
| translation | ✅ | ✅ | 完整 |
|
||||
|
||||
## 🔍 详细对比分析
|
||||
|
||||
### 1. 主要模块对比
|
||||
|
||||
#### 1.1 基础信息 (title, memoryBear)
|
||||
- ✅ **完全匹配**
|
||||
- 中文: "记忆熊.AI"
|
||||
- 英文: "Memory Bear.AI"
|
||||
|
||||
#### 1.2 首页模块 (index)
|
||||
- ✅ **完全匹配** - 包含所有子键
|
||||
|
||||
#### 1.3 版本信息 (version)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.4 快速操作 (quickActions)
|
||||
- ✅ **完全匹配** - 包含所有功能入口
|
||||
|
||||
#### 1.5 引导模块 (guide)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.6 首页引导 (indexTour)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.7 菜单模块 (menu)
|
||||
- ✅ **完全匹配** - 包含所有导航项
|
||||
|
||||
#### 1.8 仪表盘 (dashboard)
|
||||
- ✅ **完全匹配** - 包含所有统计指标
|
||||
|
||||
#### 1.9 表格 (table)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.10 头部 (header)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.11 语言 (language)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.12 用户管理 (user)
|
||||
- ✅ **完全匹配** - 包含所有用户相关功能
|
||||
|
||||
#### 1.13 时区 (timezones)
|
||||
- ✅ **完全匹配** - 包含全球主要时区
|
||||
|
||||
#### 1.14 通用 (common)
|
||||
- ✅ **完全匹配** - 包含所有通用操作和提示
|
||||
|
||||
#### 1.15 模型管理 (model)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.16 新模型管理 (modelNew)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.17 知识库 (knowledgeBase)
|
||||
- ✅ **完全匹配** - 包含所有知识库功能
|
||||
- 包含知识图谱相关配置
|
||||
|
||||
#### 1.18 API (api)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.19 记忆管理 (memory)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.20 成员管理 (member)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.21 记忆摘要 (memorySummary)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.22 遗忘引擎 (forgettingEngine)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.23 应用管理 (application)
|
||||
- ✅ **完全匹配** - 包含所有应用配置功能
|
||||
- 包含工作流、Agent配置等
|
||||
|
||||
#### 1.24 用户记忆 (userMemory)
|
||||
- ✅ **完全匹配** - 包含所有记忆类型
|
||||
|
||||
#### 1.25 空间管理 (space)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.26 记忆萃取引擎 (memoryExtractionEngine)
|
||||
- ✅ **完全匹配** - 包含所有配置参数
|
||||
|
||||
#### 1.27 记忆对话 (memoryConversation)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.28 登录 (login)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.29 空状态 (empty)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.30 API密钥 (apiKey)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.31 工具管理 (tool)
|
||||
- ✅ **完全匹配** - 包含MCP服务、内置工具、自定义工具
|
||||
|
||||
#### 1.32 工作流 (workflow)
|
||||
- ✅ **完全匹配** - 包含所有节点配置
|
||||
|
||||
#### 1.33 情感引擎 (emotionEngine)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.34 情感详情 (statementDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.35 反思引擎 (reflectionEngine)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.36 定价 (pricing)
|
||||
- ✅ **完全匹配** - 包含所有套餐信息
|
||||
|
||||
#### 1.37 遗忘详情 (forgetDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.38 情景记忆详情 (episodicDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.39 内隐记忆详情 (implicitDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.40 短期记忆详情 (shortTermDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.41 感知记忆详情 (perceptualDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.42 外显记忆详情 (explicitDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.43 工作记忆详情 (workingDetail)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.44 本体工程 (ontology)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.45 提示词工程 (prompt)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
#### 1.46 技能库 (skills)
|
||||
- ✅ **完全匹配**
|
||||
|
||||
## ✅ 结论
|
||||
|
||||
### 整体评估
|
||||
- **状态**: 🟢 完全同步
|
||||
- **中英文键值对**: 完全匹配
|
||||
- **结构一致性**: 100%
|
||||
|
||||
### 优点
|
||||
1. ✅ 所有模块的中英文翻译完整
|
||||
2. ✅ 键名结构完全一致
|
||||
3. ✅ 嵌套层级对应准确
|
||||
4. ✅ 特殊字符和变量占位符使用正确
|
||||
5. ✅ 时区、语言等枚举值完整
|
||||
|
||||
### 建议
|
||||
1. 定期检查新增功能的国际化覆盖
|
||||
2. 建议添加自动化测试确保中英文键值对同步
|
||||
3. 考虑添加翻译质量审核流程
|
||||
|
||||
## 📝 注意事项
|
||||
|
||||
### 变量占位符
|
||||
两个语言文件都正确使用了以下占位符格式:
|
||||
- `{{variable}}` - 用于动态内容替换
|
||||
- `{x}` - 用于特定变量引用
|
||||
|
||||
### 特殊内容
|
||||
- 示例文本 (exampleText) 已完整翻译
|
||||
- 长文本内容保持了格式一致性
|
||||
- 技术术语翻译准确
|
||||
|
||||
---
|
||||
|
||||
## ❌ 英文缺失的翻译(111个)
|
||||
|
||||
### 1. Application 模块 (3个)
|
||||
- `application.cluster` - 集群
|
||||
- `application.clusterDesc` - 创建Agent集群
|
||||
- `application.fullAmount` - 全量
|
||||
|
||||
### 2. Role 角色管理模块 (15个)
|
||||
- `role.roleManagement` - 角色管理
|
||||
- `role.roleId` - 角色ID
|
||||
- `role.roleName` - 角色名称
|
||||
- `role.roleCode` - 角色编码
|
||||
- `role.description` - 角色描述
|
||||
- `role.status` - 状态
|
||||
- `role.enabled` - 已启用
|
||||
- `role.disabled` - 已停用
|
||||
- `role.createTime` - 创建时间
|
||||
- `role.createRole` - 新建角色
|
||||
- `role.editRole` - 编辑角色
|
||||
- `role.roleTemplate` - 角色模板
|
||||
- `role.emptyTemplate` - 空模板
|
||||
- `role.adminTemplate` - 管理员模板
|
||||
- `role.userTemplate` - 用户模板
|
||||
- `role.confirmDelete` - 确定要删除这个角色吗?
|
||||
- `role.createSuccess` - 角色创建成功
|
||||
- `role.updateSuccess` - 角色更新成功
|
||||
- `role.deleteSuccess` - 角色删除成功
|
||||
- `role.createFailed` - 角色创建失败
|
||||
- `role.updateFailed` - 角色更新失败
|
||||
- `role.deleteFailed` - 角色删除失败
|
||||
|
||||
### 3. Tenant 租户管理模块 (20个)
|
||||
- `tenant.tenantId` - 租户ID
|
||||
- `tenant.tenantName` - 租户名称
|
||||
- `tenant.contactPerson` - 联系人
|
||||
- `tenant.contactInfo` - 联系方式
|
||||
- `tenant.status` - 状态
|
||||
- `tenant.enabled` - 启用
|
||||
- `tenant.disabled` - 禁用
|
||||
- `tenant.expiryDate` - 到期时间
|
||||
- `tenant.createTenant` - 新增租户
|
||||
- `tenant.editTenant` - 编辑租户
|
||||
- `tenant.searchPlaceholder` - 搜索租户ID、名称、联系人或联系方式
|
||||
- `tenant.confirmDelete` - 确定要删除该租户吗?
|
||||
- `tenant.confirmBatchDelete` - 确定要批量删除选中的租户吗?
|
||||
- `tenant.fetchFailed` - 获取租户数据失败
|
||||
- `tenant.batchEnableSuccess` - 批量启用成功
|
||||
- `tenant.batchEnableFailed` - 批量启用失败
|
||||
- `tenant.batchDisableSuccess` - 批量停用成功
|
||||
- `tenant.batchDisableFailed` - 批量停用失败
|
||||
- `tenant.exportSuccess` - 导出成功
|
||||
- `tenant.batchDeleteSuccess` - 批量删除成功
|
||||
- `tenant.batchDeleteFailed` - 批量删除失败
|
||||
- `tenant.saveFailed` - 保存失败
|
||||
- `tenant.batchImport` - 批量导入
|
||||
|
||||
### 4. User 用户管理模块 (13个)
|
||||
- `user.tenantName` - 所属租户
|
||||
- `user.password` - 密码
|
||||
- `user.expiryDate` - 有效期
|
||||
- `user.expiryDateDue` - 有效期至
|
||||
- `user.batchImport` - 批量导入
|
||||
- `user.batchImportUser` - 批量导入用户
|
||||
- `user.downloadTemplate` - 下载导入模板
|
||||
- `user.templateDownloadSuccess` - 模板下载成功
|
||||
- `user.startImport` - 开始导入
|
||||
- `user.batchImportSuccess` - 批量导入成功
|
||||
- `user.importFailed` - 导入失败,请检查文件格式
|
||||
- `user.noFileSelected` - 请选择要导入的文件
|
||||
- `user.onlyXlsxOrCsv` - 只能上传 .xlsx 或 .csv 格式的文件
|
||||
- `user.reselect` - 重新选择
|
||||
- `user.noFileSelectedTip` - 未选择任何文件
|
||||
- `user.downloadTemplateTip` - 请下载模板,填写用户信息后上传。
|
||||
|
||||
### 5. Product 产品管理模块 (13个)
|
||||
- `product.applicationManagement` - 应用管理
|
||||
- `product.createApplication` - 创建应用
|
||||
- `product.applicationName` - 应用名称
|
||||
- `product.applicationIcon` - 应用图标
|
||||
- `product.applicationNameRequired` - 请输入应用名称
|
||||
- `product.associationStatus` - 关联状态
|
||||
- `product.associated` - 已关联
|
||||
- `product.notAssociated` - 未关联
|
||||
- `product.unassociate` - 解除关联
|
||||
- `product.unassociateSuccess` - 解除关联成功
|
||||
- `product.unassociateFailed` - 解除关联失败
|
||||
- `product.viewKey` - 查看KEY
|
||||
- `product.viewStats` - 查看统计
|
||||
- `product.disableSuccess` - 停用成功
|
||||
- `product.enableSuccess` - 启用成功
|
||||
- `product.operationFailed` - 操作失败
|
||||
|
||||
### 6. 其他模块 (47个)
|
||||
- `count` - 计数: {{count}}
|
||||
- `increment` - 增加
|
||||
- `decrement` - 减少
|
||||
- `reset` - 重置
|
||||
- `switchLanguage` - 切换语言
|
||||
- `home.title` - 首页
|
||||
- `home.welcome` - 欢迎使用我们的带单页路由的 React 应用!
|
||||
- `home.counterCard` - 计数器演示
|
||||
- `home.aboutCard` - 关于我们
|
||||
- `home.workflowCard` - 工作流编辑器
|
||||
- `home.websocketDemoCard` - WebSocket 演示
|
||||
- `home.sseDemoCard` - SSE演示
|
||||
- `workflow.title` - 工作流编辑器
|
||||
- `workflow.description` - 拖拽节点创建连接,构建您的工作流程。点击节点可进行配置。
|
||||
- `workflow.addNode` - 添加节点
|
||||
- `workflow.deleteNode` - 删除选中
|
||||
- `workflow.saveWorkflow` - 保存工作流
|
||||
- `workflow.startNode` - 触发节点
|
||||
- `workflow.conditionNode` - 条件判断
|
||||
- `workflow.actionNode` - 执行动作
|
||||
- `workflow.endNode` - 结束节点
|
||||
- `workflow.newNode` - 新节点
|
||||
- `workflow.node` - 节点
|
||||
- `workflow.nodesCreated` - 已创建节点
|
||||
- `workflow.loadingNodes` - 正在加载节点 {{progress}}%
|
||||
- `workflow.loadingFailed` - 加载节点失败
|
||||
- `workflow.create5kNodes` - 创建5000节点
|
||||
- `workflow.create10kNodes` - 创建10000节点
|
||||
- `notFound.title` - 页面未找到
|
||||
- `notFound.description` - 请求的页面不存在。
|
||||
- `notFound.backToHome` - 返回首页
|
||||
|
||||
---
|
||||
|
||||
## ✅ 中文缺失的翻译(27个)
|
||||
|
||||
### 1. Common 通用模块 (1个)
|
||||
- `common.operateSuccess` - Operation successful
|
||||
|
||||
### 2. KnowledgeBase 知识库模块 (3个)
|
||||
- `knowledgeBase.models` - Model
|
||||
- `knowledgeBase.owner` - Owner
|
||||
- `knowledgeBase.operation` - Operation
|
||||
|
||||
### 3. Application 应用模块 (15个)
|
||||
- `application.multi_agent` - Cluster
|
||||
- `application.multi_agentDesc` - Create an Agent Cluster
|
||||
- `application.current` - Current
|
||||
- `application.versionName` - Version Name
|
||||
- `application.versionNameTip` - Version number format: v[major version number].[next version number].[revision number] (e.g. v1.3.0)
|
||||
- `application.agentName` - Agent Name
|
||||
- `application.roleType` - Role Type
|
||||
- `application.coordinator` - Coordinator
|
||||
- `application.analyzer` - Analyzer
|
||||
- `application.executor` - Executor
|
||||
- `application.reviewer` - Reviewer
|
||||
- `application.updateSubAgent` - Update Sub Agent
|
||||
- `application.subAgentMaxLength` - Sub Agent maximum {{maxLength}}
|
||||
- `application.capabilities` - Capabilities
|
||||
|
||||
### 4. Space 空间模块 (5个)
|
||||
- `space.storageType` - Storage Type
|
||||
- `space.rag` - RAG storage
|
||||
- `space.ragDesc` - Based on vector retrieval, suitable for document Q&A and semantic search
|
||||
- `space.neo4j` - Graph storage
|
||||
- `space.neo4jDesc` - Based on knowledge graph, suitable for relational reasoning and path query
|
||||
|
||||
### 5. MemoryExtractionEngine 记忆提取引擎模块 (4个)
|
||||
- `memoryExtractionEngine.coreEntitiesAfterDedup` - Core entities after deduplication
|
||||
- `memoryExtractionEngine.extractRelationalTriples` - Extracted relational triples (partial)
|
||||
- `memoryExtractionEngine.extractRelationalTriplesDesc` - There are a total of {{count}} segments with clear semantic boundaries
|
||||
- `memoryExtractionEngine.theEffectOfEntityDisambiguationLLMDriven` - The effect of entity disambiguation (LLM driven)
|
||||
|
||||
---
|
||||
|
||||
## 🎯 建议
|
||||
|
||||
### 优先级 1 - 核心功能模块(需要立即补充)
|
||||
1. **Role 角色管理** - 完整模块缺失(15个键)
|
||||
2. **Tenant 租户管理** - 完整模块缺失(20个键)
|
||||
3. **Product 产品管理** - 完整模块缺失(13个键)
|
||||
4. **User 用户管理扩展** - 批量导入功能缺失(13个键)
|
||||
|
||||
### 优先级 2 - 功能增强(建议补充)
|
||||
1. **Application 应用模块** - 多代理相关功能(15个键)
|
||||
2. **Space 空间模块** - 存储类型配置(5个键)
|
||||
3. **MemoryExtractionEngine** - 实体去重相关(4个键)
|
||||
|
||||
### 优先级 3 - 演示/测试功能(可选)
|
||||
1. **Home/Workflow/NotFound** - 演示页面(30个键)
|
||||
2. **通用计数器功能** - 测试功能(5个键)
|
||||
|
||||
---
|
||||
|
||||
## 📝 下一步行动
|
||||
|
||||
1. **补充英文翻译**: 优先补充 Role、Tenant、Product、User 模块的英文翻译
|
||||
2. **补充中文翻译**: 补充 Application、Space、MemoryExtractionEngine 模块的中文翻译
|
||||
3. **清理无用翻译**: 如果 Home/Workflow 等演示功能不再使用,可以考虑从中文文件中移除
|
||||
4. **建立翻译规范**: 建议建立翻译键的命名规范和审查流程,避免未来出现遗漏
|
||||
|
||||
**报告生成完成** ✨
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 13:59:45
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-02-28 16:34:15
|
||||
* @Last Modified time: 2026-03-03 12:08:42
|
||||
*/
|
||||
import { request } from '@/utils/request'
|
||||
import type { ApplicationModalData } from '@/views/ApplicationManagement/types'
|
||||
@@ -120,15 +120,19 @@ export const copyApplication = (app_id: string, new_name: string) => {
|
||||
export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => {
|
||||
return request.get(`/apps/${app_id}/statistics`, data)
|
||||
}
|
||||
// 导出工作流
|
||||
export const exportWorkflow = (app_id: string, fileName: string) => {
|
||||
return request.downloadFile(`/apps/${app_id}/workflow/export`, fileName, undefined, undefined, 'GET')
|
||||
}
|
||||
// 工作流上传+兼容性分析
|
||||
// Upload workflow and analyze compatibility
|
||||
export const importWorkflow = (formData: FormData) => {
|
||||
return request.uploadFile(`/apps/workflow/import`, formData)
|
||||
}
|
||||
// 完成工作流导入
|
||||
// Complete workflow import
|
||||
export const completeImportWorkflow = (data: { temp_id: string; name?: string; description?: string }) => {
|
||||
return request.post(`/apps/workflow/import/save`, data)
|
||||
}
|
||||
// Get experience config
|
||||
export const getExperienceConfig = (share_token: string) => {
|
||||
return request.get(`/public/share/config`, {}, {
|
||||
headers: {
|
||||
'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}`
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -154,7 +154,7 @@ export const uploadFile = async (data: FormData, options?: UploadFileOptions) =>
|
||||
// 下载文件
|
||||
export const downloadFile = async (fileId: string, fileName?: string) => {
|
||||
const token = cookieUtils.get('authToken');
|
||||
const url = `${apiPrefix}/files/${fileId}`;
|
||||
const url = `/api/files/${fileId}`;
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user