Merge branch 'develop' into feature/ontology_zy
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -35,3 +35,6 @@ nltk_data/
|
||||
tika-server*.jar*
|
||||
cl100k_base.tiktoken
|
||||
libssl*.deb
|
||||
|
||||
sandbox/lib/seccomp_python/target
|
||||
sandbox/lib/seccomp_nodejs/target
|
||||
|
||||
0
api/app/__init__.py
Normal file
0
api/app/__init__.py
Normal file
@@ -872,3 +872,44 @@ async def update_workflow_config(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
|
||||
return success(data=WorkflowConfigSchema.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/statistics", summary="应用统计数据")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app_statistics(
|
||||
app_id: uuid.UUID,
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取应用统计数据
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
- daily_conversations: 每日会话数统计
|
||||
- total_conversations: 总会话数
|
||||
- daily_new_users: 每日新增用户数
|
||||
- total_new_users: 总新增用户数
|
||||
- daily_api_calls: 每日API调用次数
|
||||
- total_api_calls: 总API调用次数
|
||||
- daily_tokens: 每日token消耗
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -7,11 +7,13 @@ Routes:
|
||||
GET /memory/config/emotion - 获取情绪引擎配置
|
||||
POST /memory/config/emotion - 更新情绪引擎配置
|
||||
"""
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -20,6 +22,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_config_service import EmotionConfigService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.db import get_db
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -32,11 +35,11 @@ router = APIRouter(
|
||||
|
||||
class EmotionConfigQuery(BaseModel):
|
||||
"""情绪配置查询请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: UUID = Field(..., description="配置ID")
|
||||
|
||||
class EmotionConfigUpdate(BaseModel):
|
||||
"""情绪配置更新请求模型"""
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
|
||||
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
|
||||
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
|
||||
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
|
||||
@@ -45,7 +48,7 @@ class EmotionConfigUpdate(BaseModel):
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
def get_emotion_config(
|
||||
config_id: int = Query(..., description="配置ID"),
|
||||
config_id: UUID|int = Query(..., description="配置ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -78,7 +81,7 @@ def get_emotion_config(
|
||||
f"用户 {current_user.username} 请求获取情绪配置",
|
||||
extra={"config_id": config_id}
|
||||
)
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 初始化服务
|
||||
config_service = EmotionConfigService(db)
|
||||
|
||||
@@ -157,6 +160,7 @@ def update_emotion_config(
|
||||
}
|
||||
}
|
||||
"""
|
||||
config.config_id=resolve_config_id(config.config_id, db)
|
||||
try:
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求更新情绪配置",
|
||||
|
||||
@@ -53,7 +53,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
@@ -63,7 +63,7 @@ async def get_emotion_tags(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
@@ -73,7 +73,7 @@ async def get_emotion_tags(
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_count": data.get("total_count", 0),
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
@@ -84,7 +84,7 @@ async def get_emotion_tags(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -105,7 +105,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"emotion_type": request.emotion_type,
|
||||
"limit": request.limit
|
||||
}
|
||||
@@ -113,7 +113,7 @@ async def get_emotion_wordcloud(
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
@@ -121,7 +121,7 @@ async def get_emotion_wordcloud(
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
@@ -131,7 +131,7 @@ async def get_emotion_wordcloud(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -159,21 +159,21 @@ async def get_emotion_health(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
@@ -186,7 +186,7 @@ async def get_emotion_health(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪健康指数失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -206,7 +206,7 @@ async def get_emotion_suggestions(
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
request: 包含 end_user_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
@@ -217,22 +217,22 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
# 从缓存获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
f"用户 {request.group_id} 的建议缓存不存在或已过期",
|
||||
extra={"group_id": request.group_id}
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
@@ -243,7 +243,7 @@ async def get_emotion_suggestions(
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
@@ -253,7 +253,7 @@ async def get_emotion_suggestions(
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
|
||||
@@ -310,7 +310,7 @@ async def get_file_url(
|
||||
try:
|
||||
if permanent:
|
||||
# Generate permanent URL (no expiration check)
|
||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
url = f"{server_url}/storage/permanent/{file_id}"
|
||||
return success(
|
||||
data={
|
||||
|
||||
@@ -122,10 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
@@ -137,7 +137,7 @@ async def get_preference_tags(
|
||||
Get user preference tags from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
@@ -146,20 +146,20 @@ async def get_preference_tags(
|
||||
Returns:
|
||||
List of preference tags from cache
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -192,17 +192,17 @@ async def get_preference_tags(
|
||||
|
||||
filtered_preferences.append(pref)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -211,26 +211,26 @@ async def get_dimension_portrait(
|
||||
Get user's four-dimension personality portrait from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_history: Whether to include historical trend data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait from cache
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -240,17 +240,17 @@ async def get_dimension_portrait(
|
||||
# Extract portrait from cache
|
||||
portrait = cached_profile.get("portrait", {})
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=portrait, msg="四维画像获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
@@ -259,26 +259,26 @@ async def get_interest_area_distribution(
|
||||
Get user's interest area distribution from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
include_trends: Whether to include trend analysis data (ignored for cached data)
|
||||
|
||||
Returns:
|
||||
Interest area distribution from cache
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -288,17 +288,17 @@ async def get_interest_area_distribution(
|
||||
# Extract interest areas from cache
|
||||
interest_areas = cached_profile.get("interest_areas", {})
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
|
||||
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
end_user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
@@ -309,7 +309,7 @@ async def get_behavior_habits(
|
||||
Get user's behavioral habits from cache.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
end_user_id: Target end user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
@@ -317,20 +317,20 @@ async def get_behavior_habits(
|
||||
Returns:
|
||||
List of behavioral habits from cache
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_user_id(end_user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
|
||||
# Get cached profile
|
||||
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
|
||||
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
|
||||
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
api_logger.info(f"用户 {end_user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
@@ -368,11 +368,11 @@ async def get_behavior_habits(
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
|
||||
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
|
||||
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -125,7 +125,7 @@ async def write_server(
|
||||
Write service endpoint - processes write operations synchronously
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
@@ -160,19 +160,18 @@ async def write_server(
|
||||
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
|
||||
storage_type = 'neo4j'
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
messages_list, # 传递结构化消息列表
|
||||
user_input.end_user_id,
|
||||
messages_list,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -196,7 +195,7 @@ async def write_server_async(
|
||||
Async write service endpoint - enqueues write processing to Celery
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and group_id
|
||||
user_input: Write request containing message and end_user_id
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
@@ -226,10 +225,10 @@ async def write_server_async(
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -255,16 +254,14 @@ async def read_server(
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
user_input: Read request with message, history, search_switch, and group_id
|
||||
user_input: Read request with message, history, search_switch, and end_user_id
|
||||
|
||||
Returns:
|
||||
Response with query answer
|
||||
"""
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -279,12 +276,13 @@ async def read_server(
|
||||
name="USER_RAG_MERORY",
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
if knowledge:
|
||||
user_rag_memory_id = str(knowledge.id)
|
||||
|
||||
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
|
||||
try:
|
||||
result = await memory_agent_service.read_memory(
|
||||
user_input.group_id,
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.history,
|
||||
user_input.search_switch,
|
||||
@@ -295,17 +293,20 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
end_user_id=user_input.end_user_id,
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
if "信息不足,无法回答" in result['answer']:
|
||||
result['answer']=retrieve_info
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -403,7 +404,7 @@ async def read_server_async(
|
||||
try:
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
|
||||
config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Read task queued: {task.id}")
|
||||
@@ -447,7 +448,7 @@ async def get_read_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -524,7 +525,7 @@ async def get_write_task_result(
|
||||
return success(
|
||||
data={
|
||||
"result": task_result.get("result"),
|
||||
"group_id": task_result.get("group_id"),
|
||||
"end_user_id": task_result.get("end_user_id"),
|
||||
"elapsed_time": task_result.get("elapsed_time"),
|
||||
"task_id": task_id
|
||||
},
|
||||
@@ -578,16 +579,16 @@ async def status_type(
|
||||
Determine the type of user message (read or write)
|
||||
|
||||
Args:
|
||||
user_input: Request containing user message and group_id
|
||||
user_input: Request containing user message and end_user_id
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
@@ -595,11 +596,11 @@ async def status_type(
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
@@ -624,7 +625,7 @@ async def get_knowledge_type_stats_api(
|
||||
会对缺失类型补 0,返回字典形式。
|
||||
可选按状态过滤。
|
||||
- 知识库类型根据当前用户的 current_workspace_id 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
|
||||
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤
|
||||
- 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0
|
||||
"""
|
||||
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
|
||||
@@ -697,7 +698,7 @@ async def get_user_profile_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取工作空间下Popular Memory Tags,包含:
|
||||
获取用户详情,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
@@ -49,63 +49,134 @@ async def get_workspace_end_users(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
返回格式与原 memory_list 接口中的 end_users 字段相同,
|
||||
并包含每个用户的记忆配置信息(memory_config_id 和 memory_config_name)
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
memory_configs_map = {}
|
||||
if end_user_ids:
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
get_end_users_connected_configs_batch,
|
||||
end_user_ids, db
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
# 失败时使用空字典,不影响其他数据返回
|
||||
return {}
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
# RAG 模式:批量查询
|
||||
try:
|
||||
chunk_map = await asyncio.to_thread(
|
||||
memory_dashboard_service.get_users_total_chunk_batch,
|
||||
end_user_ids, db, current_user
|
||||
)
|
||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
for end_user in end_users:
|
||||
memory_num = {}
|
||||
if current_workspace_type == "neo4j":
|
||||
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
|
||||
memory_num = await memory_storage_service.search_all(str(end_user.id))
|
||||
elif current_workspace_type == "rag":
|
||||
memory_num = {
|
||||
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
|
||||
}
|
||||
|
||||
# 从批量查询结果中获取配置信息
|
||||
user_id = str(end_user.id)
|
||||
memory_config_info = memory_configs_map.get(user_id, {
|
||||
"memory_config_id": None,
|
||||
"memory_config_name": None
|
||||
})
|
||||
|
||||
# 只保留需要的字段,移除 error 字段(如果有)
|
||||
memory_config = {
|
||||
"memory_config_id": memory_config_info.get("memory_config_id"),
|
||||
"memory_config_name": memory_config_info.get("memory_config_name")
|
||||
}
|
||||
|
||||
result.append(
|
||||
{
|
||||
'end_user': end_user,
|
||||
'memory_num': memory_num,
|
||||
'memory_config': memory_config
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
},
|
||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||
'memory_config': {
|
||||
"memory_config_id": config_info.get("memory_config_id"),
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
)
|
||||
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -33,7 +34,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -83,7 +84,8 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
config_id = resolve_config_id((config_id), db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
@@ -106,7 +108,7 @@ async def trigger_forgetting_cycle(
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=end_user_id, # 服务层方法的参数名是 group_id
|
||||
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=config_id
|
||||
@@ -128,7 +130,7 @@ async def trigger_forgetting_cycle(
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -157,6 +159,7 @@ async def read_forgetting_config(
|
||||
)
|
||||
|
||||
try:
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
@@ -194,6 +197,8 @@ async def update_forgetting_config(
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id=resolve_config_id((payload.config_id), db)
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -236,7 +241,7 @@ async def update_forgetting_config(
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -246,7 +251,7 @@ async def get_forgetting_stats(
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
end_user_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
@@ -254,26 +259,25 @@ async def get_forgetting_stats(
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
# 如果提供了 end_user_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
@@ -283,14 +287,14 @@ async def get_forgetting_stats(
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
f"end_user_id={end_user_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
@@ -324,7 +328,7 @@ async def get_forgetting_curve(
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
request.config_id = resolve_config_id((request.config_id), db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
|
||||
@@ -27,27 +27,27 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve perceptual memory statistics for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group (usually end_user_id in this context)
|
||||
end_user_id: ID of the user group (usually end_user_id in this context)
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Response containing memory count statistics
|
||||
"""
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
count_stats = service.get_memory_count(group_id)
|
||||
count_stats = service.get_memory_count(end_user_id)
|
||||
|
||||
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
|
||||
|
||||
@@ -57,37 +57,37 @@ def get_memory_count(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch memory statistics",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
|
||||
def get_last_visual_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent VISION-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest visual memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
visual_memory = service.get_latest_visual_memory(group_id)
|
||||
visual_memory = service.get_latest_visual_memory(end_user_id)
|
||||
|
||||
if visual_memory is None:
|
||||
api_logger.info(f"No visual memory found: group_id={group_id}")
|
||||
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No visual memory available"
|
||||
@@ -101,37 +101,37 @@ def get_last_visual_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest visual memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
|
||||
def get_last_memory_listen(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent AUDIO-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest audio memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
service = MemoryPerceptualService(db)
|
||||
audio_memory = service.get_latest_audio_memory(group_id)
|
||||
audio_memory = service.get_latest_audio_memory(end_user_id)
|
||||
|
||||
if audio_memory is None:
|
||||
api_logger.info(f"No audio memory found: group_id={group_id}")
|
||||
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No audio memory available"
|
||||
@@ -145,38 +145,38 @@ def get_last_memory_listen(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest audio memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
|
||||
def get_last_text_memory(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Retrieve the most recent TEXT-type memory for a user.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
current_user: Current authenticated user
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
ApiResponse: Metadata of the latest text memory
|
||||
"""
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
|
||||
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
|
||||
|
||||
try:
|
||||
# 调用服务层获取最近的文本记忆
|
||||
service = MemoryPerceptualService(db)
|
||||
text_memory = service.get_latest_text_memory(group_id)
|
||||
text_memory = service.get_latest_text_memory(end_user_id)
|
||||
|
||||
if text_memory is None:
|
||||
api_logger.info(f"No text memory found: group_id={group_id}")
|
||||
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
|
||||
return success(
|
||||
data=None,
|
||||
msg="No text memory available"
|
||||
@@ -190,16 +190,16 @@ def get_last_text_memory(
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
|
||||
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg="Failed to fetch latest text memory",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
|
||||
def get_memory_time_line(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
|
||||
@@ -209,7 +209,7 @@ def get_memory_time_line(
|
||||
"""Retrieve a timeline of perceptual memories for a user group.
|
||||
|
||||
Args:
|
||||
group_id: ID of the user group
|
||||
end_user_id: ID of the user group
|
||||
perceptual_type: Optional filter for perceptual type
|
||||
page: Page number for pagination
|
||||
page_size: Number of items per page
|
||||
@@ -221,7 +221,7 @@ def get_memory_time_line(
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Fetching perceptual memory timeline: user={current_user.username}, "
|
||||
f"group_id={group_id}, type={perceptual_type}, page={page}"
|
||||
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -232,7 +232,7 @@ def get_memory_time_line(
|
||||
)
|
||||
|
||||
service = MemoryPerceptualService(db)
|
||||
timeline_data = service.get_time_line(group_id, query)
|
||||
timeline_data = service.get_time_line(end_user_id, query)
|
||||
|
||||
api_logger.info(
|
||||
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
|
||||
@@ -246,7 +246,7 @@ def get_memory_time_line(
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
|
||||
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
|
||||
f"error={str(e)}"
|
||||
)
|
||||
return fail(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
@@ -11,7 +12,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_reflection_schemas import Memory_Reflection
|
||||
from app.services.memory_reflection_service import (
|
||||
@@ -24,6 +25,8 @@ from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -42,6 +45,7 @@ async def save_reflection_config(
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
if not config_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -50,7 +54,7 @@ async def save_reflection_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
data_config = DataConfigRepository.update_reflection_config(
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
@@ -63,17 +67,17 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
db.commit()
|
||||
db.refresh(data_config)
|
||||
db.refresh(memory_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": data_config.config_id,
|
||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||
"iteration_period": data_config.iteration_period,
|
||||
"reflexion_range": data_config.reflexion_range,
|
||||
"baseline": data_config.baseline,
|
||||
"reflection_model_id": data_config.reflection_model_id,
|
||||
"memory_verify": data_config.memory_verify,
|
||||
"quality_assessment": data_config.quality_assessment}
|
||||
"config_id": memory_config.config_id,
|
||||
"enable_self_reflexion": memory_config.enable_self_reflexion,
|
||||
"iteration_period": memory_config.iteration_period,
|
||||
"reflexion_range": memory_config.reflexion_range,
|
||||
"baseline": memory_config.baseline,
|
||||
"reflection_model_id": memory_config.reflection_model_id,
|
||||
"memory_verify": memory_config.memory_verify,
|
||||
"quality_assessment": memory_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -111,14 +115,14 @@ async def start_workspace_reflection(
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
if data['memory_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
for base, config, user in zip(releases, memory_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
@@ -156,17 +160,20 @@ async def start_workspace_reflection(
|
||||
|
||||
@router.get("/reflection/configs")
|
||||
async def start_reflection_configs(
|
||||
config_id: int,
|
||||
config_id: uuid.UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
"reflection_period_in_hours": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
@@ -191,7 +198,7 @@ async def start_reflection_configs(
|
||||
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
@@ -199,9 +206,9 @@ async def reflection_run(
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -34,6 +35,8 @@ from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
@@ -140,7 +143,6 @@ def create_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
|
||||
@@ -160,12 +162,12 @@ def create_config(
|
||||
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: str,
|
||||
config_id: UUID|int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
@@ -187,7 +189,7 @@ def update_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
@@ -210,7 +212,7 @@ def update_config_extracted(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
|
||||
@@ -232,12 +234,12 @@ def update_config_extracted(
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
config_id: str,
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
|
||||
@@ -285,6 +287,7 @@ async def pilot_run(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
@@ -420,15 +423,95 @@ async def get_hot_memory_tags_api(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
|
||||
"""
|
||||
获取热门记忆标签(带Redis缓存)
|
||||
|
||||
缓存策略:
|
||||
- 缓存键:workspace_id + limit
|
||||
- 过期时间:5分钟(300秒)
|
||||
- 缓存命中:~50ms
|
||||
- 缓存未命中:~600-800ms(取决于LLM速度)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 构建缓存键
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
|
||||
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
import json
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
api_logger.info(f"Cache hit for key: {cache_key}")
|
||||
try:
|
||||
data = json.loads(cached_result)
|
||||
return success(data=data, msg="查询成功(缓存)")
|
||||
except json.JSONDecodeError:
|
||||
api_logger.warning(f"Failed to parse cached data, will refresh")
|
||||
|
||||
# 缓存未命中,执行查询
|
||||
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
|
||||
result = await analytics_hot_memory_tags(db, current_user, limit)
|
||||
|
||||
# 写入缓存(过期时间:5分钟)
|
||||
# 注意:result是列表,需要转换为JSON字符串
|
||||
try:
|
||||
cache_data = json.dumps(result, ensure_ascii=False)
|
||||
await aio_redis_set(cache_key, cache_data, expire=300)
|
||||
api_logger.info(f"Cached result for key: {cache_key}")
|
||||
except Exception as cache_error:
|
||||
# 缓存写入失败不影响主流程
|
||||
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
|
||||
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
|
||||
async def clear_hot_memory_tags_cache(
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
清除热门标签缓存
|
||||
|
||||
用于:
|
||||
- 手动刷新数据
|
||||
- 调试和测试
|
||||
- 数据更新后立即生效
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
|
||||
|
||||
try:
|
||||
from app.aioRedis import aio_redis_delete
|
||||
|
||||
# 清除所有limit的缓存(常见的limit值)
|
||||
cleared_count = 0
|
||||
for limit in [5, 10, 15, 20, 30, 50]:
|
||||
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
|
||||
result = await aio_redis_delete(cache_key)
|
||||
if result:
|
||||
cleared_count += 1
|
||||
api_logger.info(f"Cleared cache for key: {cache_key}")
|
||||
|
||||
return success(
|
||||
data={"cleared_count": cleared_count},
|
||||
msg=f"成功清除 {cleared_count} 个缓存"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Clear cache failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -20,18 +20,18 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/count", response_model=ApiResponse)
|
||||
def get_memory_count(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||
def get_conversations(
|
||||
group_id: uuid.UUID,
|
||||
end_user_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -39,7 +39,7 @@ def get_conversations(
|
||||
Retrieve all conversations for the current user in a specific group.
|
||||
|
||||
Args:
|
||||
group_id (UUID): The group identifier.
|
||||
end_user_id (UUID): The group identifier.
|
||||
current_user (User, optional): The authenticated user.
|
||||
db (Session, optional): SQLAlchemy session.
|
||||
|
||||
@@ -53,7 +53,7 @@ def get_conversations(
|
||||
"""
|
||||
conversation_service = ConversationService(db)
|
||||
conversations = conversation_service.get_user_conversations(
|
||||
group_id
|
||||
end_user_id
|
||||
)
|
||||
return success(data=[
|
||||
{
|
||||
@@ -63,7 +63,7 @@ def get_conversations(
|
||||
], msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||
def get_messages(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -100,7 +100,7 @@ def get_messages(
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
|
||||
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
|
||||
async def get_conversation_detail(
|
||||
conversation_id: uuid.UUID,
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -3,15 +3,17 @@ from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
|
||||
from app.models.user_model import User
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
from app.schemas import model_schema
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -24,24 +26,83 @@ router = APIRouter(
|
||||
|
||||
@router.get("/type", response_model=ApiResponse)
|
||||
def get_model_types():
|
||||
|
||||
return success(msg="获取模型类型成功", data=list(ModelType))
|
||||
|
||||
|
||||
@router.get("/provider", response_model=ApiResponse)
|
||||
def get_model_providers():
|
||||
return success(msg="获取模型提供商成功", data=list(ModelProvider))
|
||||
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
return success(msg="获取模型提供商成功", data=providers)
|
||||
|
||||
@router.get("/strategy", response_model=ApiResponse)
|
||||
def get_model_strategies():
|
||||
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_model_list(
|
||||
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取模型配置列表
|
||||
|
||||
支持多个 type 参数:
|
||||
- 单个:?type=LLM
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(
|
||||
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/new", response_model=ApiResponse)
|
||||
def get_model_list_new(
|
||||
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING)"),
|
||||
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
|
||||
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
|
||||
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -53,36 +114,127 @@ def get_model_list(
|
||||
- 多个(逗号分隔):?type=LLM,EMBEDDING
|
||||
- 多个(重复参数):?type=LLM&type=EMBEDDING
|
||||
"""
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
|
||||
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
# 解析 type 参数(支持逗号分隔)
|
||||
type_list = None
|
||||
if type:
|
||||
type_values = [t.strip() for t in type.split(',')]
|
||||
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
|
||||
type_list = []
|
||||
if type is not None:
|
||||
flat_type = []
|
||||
for item in type:
|
||||
split_items = [t.strip() for t in item.split(',') if t.strip()]
|
||||
flat_type.extend(split_items)
|
||||
|
||||
unique_flat_type = list(dict.fromkeys(flat_type))
|
||||
type_list = [ModelType(t.lower()) for t in unique_flat_type]
|
||||
|
||||
api_logger.error(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQuery(
|
||||
api_logger.info(f"获取模型type_list: {type_list}")
|
||||
query = model_schema.ModelConfigQueryNew(
|
||||
type=type_list,
|
||||
provider=provider,
|
||||
is_active=is_active,
|
||||
is_public=is_public,
|
||||
search=search,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
is_composite=is_composite,
|
||||
search=search
|
||||
)
|
||||
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
|
||||
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
result = PageData.model_validate(result_orm)
|
||||
api_logger.info(f"模型配置列表获取成功: 总数={result.page.total}, 当前页={len(result.items)}")
|
||||
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
|
||||
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
|
||||
return success(data=result, msg="模型配置列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取模型配置列表失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/model_plaza", response_model=ApiResponse)
|
||||
def get_model_plaza_list(
|
||||
type: Optional[ModelType] = Query(None, description="模型类型"),
|
||||
provider: Optional[ModelProvider] = Query(None, description="供应商"),
|
||||
is_official: Optional[bool] = Query(None, description="是否官方模型"),
|
||||
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""模型广场查询接口(按供应商分组)"""
|
||||
|
||||
query = model_schema.ModelBaseQuery(
|
||||
type=type,
|
||||
provider=provider,
|
||||
is_official=is_official,
|
||||
is_deprecated=is_deprecated,
|
||||
search=search
|
||||
)
|
||||
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
|
||||
return success(data=result, msg="模型广场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def get_model_base_by_id(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取基础模型详情"""
|
||||
|
||||
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza", response_model=ApiResponse)
|
||||
def create_model_base(
|
||||
data: model_schema.ModelBaseCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建基础模型"""
|
||||
|
||||
result = ModelBaseService.create_model_base(db=db, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
|
||||
|
||||
|
||||
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def update_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
data: model_schema.ModelBaseUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新基础模型"""
|
||||
|
||||
# 不允许更改type类型
|
||||
if data.type is not None or data.provider is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
|
||||
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
|
||||
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
|
||||
|
||||
|
||||
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
|
||||
def delete_model_base(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除基础模型"""
|
||||
|
||||
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
|
||||
return success(msg="基础模型删除成功")
|
||||
|
||||
|
||||
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
|
||||
def add_model_from_plaza(
|
||||
model_base_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从模型广场添加模型到模型列表"""
|
||||
|
||||
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
|
||||
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
|
||||
|
||||
|
||||
@router.get("/{model_id}", response_model=ApiResponse)
|
||||
def get_model_by_id(
|
||||
model_id: uuid.UUID,
|
||||
@@ -138,6 +290,73 @@ async def create_model(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/composite", response_model=ApiResponse)
|
||||
async def create_composite_model(
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建组合模型
|
||||
|
||||
- 绑定一个或多个现有的 API Key
|
||||
- 所有 API Key 必须来自非组合模型
|
||||
- 所有 API Key 关联的模型类型必须与组合模型类型一致
|
||||
"""
|
||||
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
|
||||
|
||||
try:
|
||||
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/composite/{model_id}", response_model=ApiResponse)
|
||||
async def update_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
model_data: model_schema.CompositeModelCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新组合模型"""
|
||||
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
if model_data.type is not None:
|
||||
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
|
||||
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
|
||||
|
||||
result = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return success(data=result, msg="组合模型更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.delete("/composite/{model_id}", response_model=ApiResponse)
|
||||
def delete_composite_model(
|
||||
model_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除组合模型"""
|
||||
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
|
||||
api_logger.info(f"组合模型删除成功: model_id={model_id}")
|
||||
return success(msg="组合模型删除成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.put("/{model_id}", response_model=ApiResponse)
|
||||
def update_model(
|
||||
model_id: uuid.UUID,
|
||||
@@ -214,6 +433,53 @@ def get_model_api_keys(
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/provider/apikeys", response_model=ApiResponse)
|
||||
async def create_model_api_key_by_provider(
|
||||
api_key_data: model_schema.ModelApiKeyCreateByProvider,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
根据供应商为所有匹配的模型创建API Key
|
||||
"""
|
||||
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 根据tenant_id和provider筛选model_config_id列表
|
||||
model_config_ids = api_key_data.model_config_ids
|
||||
if not model_config_ids:
|
||||
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
|
||||
db=db,
|
||||
tenant_id=current_user.tenant_id,
|
||||
provider=api_key_data.provider
|
||||
)
|
||||
|
||||
if not model_config_ids:
|
||||
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
# 构造schema并调用service
|
||||
create_data = model_schema.ModelApiKeyCreateByProvider(
|
||||
provider=api_key_data.provider,
|
||||
api_key=api_key_data.api_key,
|
||||
api_base=api_key_data.api_base,
|
||||
description=api_key_data.description,
|
||||
config=api_key_data.config,
|
||||
is_active=api_key_data.is_active,
|
||||
priority=api_key_data.priority,
|
||||
model_config_ids=model_config_ids
|
||||
)
|
||||
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
|
||||
|
||||
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
|
||||
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
|
||||
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
|
||||
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
|
||||
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建API Key失败: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_model_api_key(
|
||||
model_id: uuid.UUID,
|
||||
@@ -228,11 +494,12 @@ async def create_model_api_key(
|
||||
|
||||
try:
|
||||
# 设置模型配置ID
|
||||
api_key_data.model_config_id = model_id
|
||||
api_key_data.model_config_ids = [model_id]
|
||||
|
||||
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
|
||||
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
|
||||
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
|
||||
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
|
||||
result = model_schema.ModelApiKey.model_validate(result_orm)
|
||||
return success(data=result, msg="模型API Key创建成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
|
||||
@@ -334,5 +601,3 @@ async def validate_model_config(
|
||||
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -135,3 +139,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -317,9 +317,12 @@ async def chat(
|
||||
appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
from app.models.app_model import App
|
||||
app = db.query(App).filter(App.id == appid).first()
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
|
||||
@@ -235,11 +235,11 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
@@ -268,11 +268,11 @@ async def chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -135,27 +135,27 @@ async def generate_cache_api(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
end_user_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||
)
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
@@ -175,9 +175,9 @@ async def generate_cache_api(
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ async def create_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -214,7 +214,7 @@ async def delete_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -259,7 +259,7 @@ async def validate_workflow_config(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -329,7 +329,7 @@ async def get_workflow_executions(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -389,7 +389,7 @@ async def get_workflow_execution(
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -440,7 +440,7 @@ async def run_workflow(
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active == True
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
|
||||
@@ -28,6 +28,8 @@ from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -155,13 +157,13 @@ class LangChainAgent:
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# end_user_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
@@ -175,11 +177,10 @@ class LangChainAgent:
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
@@ -188,7 +189,7 @@ class LangChainAgent:
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
@@ -196,48 +197,54 @@ class LangChainAgent:
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
||||
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
|
||||
@@ -9,6 +9,25 @@ load_dotenv()
|
||||
|
||||
|
||||
class Settings:
|
||||
# ========================================================================
|
||||
# Deployment Mode Configuration
|
||||
# ========================================================================
|
||||
# community: 社区版(开源,功能受限)
|
||||
# cloud: SaaS 云服务版(全功能,按量计费)
|
||||
# enterprise: 企业私有化版(License 控制)
|
||||
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
|
||||
|
||||
# License 配置(企业版)
|
||||
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
|
||||
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
|
||||
|
||||
# 计费服务配置(SaaS 版)
|
||||
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
|
||||
|
||||
# 基础 URL(用于 SSO 回调等)
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
|
||||
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
|
||||
|
||||
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
|
||||
# API Keys Configuration
|
||||
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
|
||||
@@ -72,6 +91,10 @@ class Settings:
|
||||
|
||||
# Single Sign-On configuration
|
||||
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
|
||||
|
||||
# SSO 免登配置
|
||||
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
|
||||
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
|
||||
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
@@ -107,6 +130,7 @@ class Settings:
|
||||
|
||||
# Server Configuration
|
||||
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
|
||||
FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
|
||||
|
||||
# ========================================================================
|
||||
# Internal Configuration (not in .env, used by application code)
|
||||
@@ -184,7 +208,7 @@ class Settings:
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
@@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
@@ -52,9 +52,9 @@ async def rag_config(state):
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
kb_config = await rag_config(state)
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
@@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
group_id=state.get('group_id', '')
|
||||
end_user_id=state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
@@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
@@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取group_id
|
||||
# 从state中获取end_user_id
|
||||
import time
|
||||
start=time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
@@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = { "end_user_id": end_user_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
@@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin):
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
group_id = state.get("group_id", '')
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
@@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
group_id = state.get("group_id", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=group_id,
|
||||
user_id=end_user_id,
|
||||
query=data,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
apply_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
@@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
group_id=state.get("group_id", '')
|
||||
end_user_id=state.get("end_user_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
@@ -236,7 +236,7 @@ async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -276,7 +276,6 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
@@ -295,9 +294,26 @@ async def Summary(state: ReadState)-> ReadState:
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
verify_expansion_issue = verify.get("verified_data", '')
|
||||
retrieve_info_str = ''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key == 'answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str += i + '\n'
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages = await summary_llm(state, history, data,
|
||||
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
|
||||
result= {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -62,12 +62,12 @@ async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, and memory_config
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
@@ -28,13 +29,11 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
@@ -79,7 +79,7 @@ async def make_read_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
group_id = '88a459f5_text09' # 组ID
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
@@ -95,9 +95,9 @@ async def main():
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
|
||||
@@ -48,11 +48,11 @@ def extract_tool_message_content(response):
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(group_id: str):
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
@@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str):
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_group_id = group_id_param or group_id
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
group_id=actual_group_id,
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
@@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str):
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
@@ -172,7 +172,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含group_id, limit, include等
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
@@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
@@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
@@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id or search_params.get("group_id"),
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
@@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
@@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
@@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
@@ -26,9 +27,21 @@ async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
@@ -42,7 +55,7 @@ async def make_write_graph():
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
group_id = 'new_2025test1103' # 组ID
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
# 获取数据库会话
|
||||
@@ -54,9 +67,9 @@ async def main():
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
|
||||
@@ -24,7 +24,7 @@ class ParameterBuilder:
|
||||
tool_call_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
@@ -44,7 +44,7 @@ class ParameterBuilder:
|
||||
tool_call_id: Extracted tool call identifier
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id
|
||||
"end_user_id": end_user_id
|
||||
}
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
|
||||
@@ -91,7 +91,7 @@ class SearchService:
|
||||
|
||||
async def execute_hybrid_search(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
question: str,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
@@ -105,7 +105,7 @@ class SearchService:
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
end_user_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
@@ -130,7 +130,7 @@ class SearchService:
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
@@ -186,7 +186,7 @@ class SearchService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Search failed for query '{question}' in group '{group_id}': {e}",
|
||||
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -9,9 +9,7 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
end_user_id: str = "group_1",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
@@ -20,9 +18,7 @@ async def get_chunked_dialogs(
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
@@ -32,42 +28,40 @@ async def get_chunked_dialogs(
|
||||
"""
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
return [dialog_data]
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
user_id:str
|
||||
apply_id:str
|
||||
group_id:str
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data:str
|
||||
data: str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
@@ -28,7 +27,7 @@ class ReadState(TypedDict):
|
||||
messages: 消息列表,支持自动追加
|
||||
loop_count: 遍历次数
|
||||
search_switch: 搜索类型开关
|
||||
group_id: 组标识
|
||||
end_user_id: 组标识
|
||||
config_id: 配置ID,用于过滤结果
|
||||
data: 从content_input_node传递的内容数据
|
||||
spit_data: 从Split_The_Problem传递的分解结果
|
||||
@@ -39,7 +38,7 @@ class ReadState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||
loop_count: int
|
||||
search_switch: str
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
# 角色
|
||||
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
# 用户问题
|
||||
{{query}}
|
||||
# 回答指南
|
||||
## 1. 仔细阅读检索信息
|
||||
- 答案可能直接或间接地出现在检索信息中
|
||||
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
|
||||
- 第三人称描述的偏好、行为通常指用户本人
|
||||
|
||||
## 2. 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
## 3. 回答要求
|
||||
- 像人类对话一样自然流畅
|
||||
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
|
||||
- 不要解释推理过程或引用信息来源
|
||||
- 保持友好、乐于助人的语气
|
||||
- 使用与问题相同的语言回答
|
||||
# 关键示例
|
||||
**示例1 - 直接匹配:**
|
||||
- 检索信息:"小曼会使用Python..."
|
||||
- 问题:"我叫什么"
|
||||
- ✓ 正确:"你叫小曼"
|
||||
- ✗ 错误:"你没有告诉我你的名字"
|
||||
**示例2 - 间接匹配:**
|
||||
- 检索信息:"用户很喜欢吃星巴克的甜品"
|
||||
- 问题:"我喜欢什么"
|
||||
- ✓ 正确:"你很喜欢吃星巴克的甜品"
|
||||
- ✗ 错误:"信息不足"
|
||||
**示例3 - 信息不匹配(推荐做法):**
|
||||
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
|
||||
- 问题:"我吃过哪家面包"
|
||||
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
|
||||
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
|
||||
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
|
||||
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
|
||||
# 重要提醒
|
||||
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
|
||||
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
|
||||
- 用对话式语言表达"不知道",而非机械模板
|
||||
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆
|
||||
@@ -0,0 +1,43 @@
|
||||
{# 角色定义 #}
|
||||
你是专业的问题解答专家+引导学者
|
||||
|
||||
{# 输入数据展示 #}
|
||||
{% if data %}
|
||||
## 输入数据
|
||||
上下文信息:
|
||||
{% for item in data.history %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
检索到的所有信息:
|
||||
{% for item in data.retrieve_info %}
|
||||
- {{ item }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
## User Query
|
||||
{{ query }}
|
||||
|
||||
{# 问题回答标准 #}
|
||||
## 问题回答核心标准
|
||||
根据上下文信息(history)和检索到的所有信息(retrieve_info)准确回答用户的问题(query)。
|
||||
注意,仔细阅读检索信息,答案可能直接或间接地出现在检索信息中或者历史上下文消息中,同时需要 判断信息相关性
|
||||
**情况A:信息匹配问题**
|
||||
- 直接回答,像自然对话一样
|
||||
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
|
||||
|
||||
**情况B:信息部分相关**
|
||||
- 先回答已知部分,再自然地询问更多信息
|
||||
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
|
||||
|
||||
**情况C:信息完全不相关**
|
||||
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
|
||||
- 使用友好的表达:
|
||||
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
|
||||
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
|
||||
- "我不记得你提到过...,但你[检索到的相关信息]"
|
||||
- 即使检索信息不直接回答问题,也可以自然地融入对话中
|
||||
- 避免僵硬的"信息不足,无法回答"
|
||||
|
||||
{# 重要提醒 #}
|
||||
当检索以及上下文的历史信息都无法回答的时候,可引导对方进行提问/回答,或者进行其他引导
|
||||
当检索或者上下文中出现了,相似的问题,可以委婉,提醒对方,我记得刚刚提过这个问题,但是我自己不记得了,能在描述一次吗~以此为例
|
||||
@@ -28,7 +28,7 @@ class RedisSessionStore:
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, group_id):
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
@@ -46,7 +46,7 @@ class RedisSessionStore:
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
@@ -67,7 +67,7 @@ class RedisSessionStore:
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
@@ -83,7 +83,7 @@ class RedisSessionStore:
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"group_id": session.get('group_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
@@ -108,9 +108,9 @@ class RedisSessionStore:
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, group_id):
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
@@ -124,7 +124,7 @@ class RedisSessionStore:
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
@@ -172,7 +172,7 @@ class RedisSessionStore:
|
||||
def delete_duplicate_sessions(self):
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
"""
|
||||
import time
|
||||
@@ -202,12 +202,12 @@ class RedisSessionStore:
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
group_id = data.get('group_id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, group_id, messages, aimessages)
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
@@ -248,9 +248,9 @@ class RedisSessionStore:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, group_id):
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
@@ -276,7 +276,7 @@ class RedisSessionStore:
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('group_id') == group_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
|
||||
@@ -59,7 +59,7 @@ class SessionService:
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
end_user_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
@@ -67,20 +67,20 @@ class SessionService:
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -89,7 +89,7 @@ class SessionService:
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
f"apply {apply_id}, group {end_user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
@@ -100,7 +100,7 @@ class SessionService:
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
@@ -110,7 +110,7 @@ class SessionService:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
@@ -131,7 +131,7 @@ class SessionService:
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
@@ -152,7 +152,7 @@ class SessionService:
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- end_user_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
|
||||
@@ -29,20 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
@@ -51,14 +49,14 @@ async def write(
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
chunker_strategy = memory_config.chunker_strategy
|
||||
config_id = str(memory_config.config_id)
|
||||
|
||||
|
||||
logger.info("=== MemSci Knowledge Extraction Pipeline ===")
|
||||
logger.info(f"Config: {memory_config.config_name} (ID: {config_id})")
|
||||
logger.info(f"Workspace: {memory_config.workspace_name}")
|
||||
logger.info(f"LLM model: {memory_config.llm_model_name}")
|
||||
logger.info(f"Embedding model: {memory_config.embedding_model_name}")
|
||||
logger.info(f"Chunker strategy: {chunker_strategy}")
|
||||
logger.info(f"Group ID: {group_id}")
|
||||
logger.info(f"end_user_id ID: {end_user_id}")
|
||||
|
||||
# Construct clients from memory_config using factory pattern with db session
|
||||
with get_db_context() as db:
|
||||
@@ -83,9 +81,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=chunker_strategy,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
|
||||
@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def get_default_docs_path() -> str:
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
from pathlib import Path
|
||||
project_root = str(Path(__file__).resolve().parents[2])
|
||||
return os.path.join(project_root, "src", "analytics", "API接口.md")
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,13 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
end_user_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
@@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if not config_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
@@ -99,9 +99,9 @@ async def get_raw_tags_from_db(
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
@@ -119,7 +119,7 @@ async def get_raw_tags_from_db(
|
||||
else:
|
||||
query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC "
|
||||
"LIMIT $limit"
|
||||
@@ -128,44 +128,44 @@ async def get_raw_tags_from_db(
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
id=end_user_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
# 验证end_user_id必须提供且不为空
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
|
||||
@@ -75,8 +75,8 @@ class MemoryDataSource:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
summary_dicts = await self.memory_summary_repo.find_by_end_user_id(
|
||||
end_user_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
|
||||
@@ -2,13 +2,16 @@ import os
|
||||
import re
|
||||
import glob
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.config.definitions import PROJECT_ROOT
|
||||
except Exception:
|
||||
# Fallback: derive project root from this file location
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
|
||||
# 需要向上 5 级到达 api/ 目录
|
||||
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
|
||||
|
||||
|
||||
def _get_latest_prompt_log_path() -> str | None:
|
||||
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
triplet_relations_count = 0
|
||||
temporal_count = 0
|
||||
|
||||
# Patterns
|
||||
# 正则表达式模式 - 匹配当前日志格式
|
||||
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
|
||||
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)")
|
||||
pat_triplet_done = re.compile(
|
||||
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)"
|
||||
pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
|
||||
pat_triplet_completed = re.compile(
|
||||
r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
|
||||
)
|
||||
pat_temporal_done = re.compile(
|
||||
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)"
|
||||
pat_temporal_completed = re.compile(
|
||||
r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
|
||||
)
|
||||
|
||||
with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
for line in f:
|
||||
# Chunk prompts count (each chunk triggers one statement-extraction prompt render)
|
||||
# 文本块数量(每个块触发一次陈述提取提示)
|
||||
if pat_chunk_render.search(line):
|
||||
chunk_count += 1
|
||||
continue
|
||||
|
||||
m1 = pat_triplet_start.search(line)
|
||||
if m1:
|
||||
# 陈述数量(每个 Triplet Started 代表一个陈述被处理)
|
||||
if pat_triplet_started.search(line):
|
||||
statements_count += 1
|
||||
continue
|
||||
|
||||
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
|
||||
m_triplet = pat_triplet_completed.search(line)
|
||||
if m_triplet:
|
||||
try:
|
||||
statements_count += int(m1.group(1))
|
||||
triplet_relations_count += int(m_triplet.group(1))
|
||||
triplet_entities_count += int(m_triplet.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m2 = pat_triplet_done.search(line)
|
||||
if m2:
|
||||
# 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
|
||||
m_temporal = pat_temporal_completed.search(line)
|
||||
if m_temporal:
|
||||
try:
|
||||
triplet_relations_count += int(m2.group(1))
|
||||
triplet_entities_count += int(m2.group(2))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
|
||||
m3 = pat_temporal_done.search(line)
|
||||
if m3:
|
||||
try:
|
||||
temporal_count += int(m3.group(1))
|
||||
temporal_count += int(m_temporal.group(1))
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
|
||||
|
||||
|
||||
def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"""Get aggregated stats from all prompt logs in logs/.
|
||||
"""Get stats from the latest prompt log file only.
|
||||
|
||||
Returns (stats_dict, message).
|
||||
"""
|
||||
all_logs = _get_all_prompt_logs()
|
||||
# Fallback to recursive search if none found in logs/
|
||||
if not all_logs:
|
||||
# 获取最新的日志文件
|
||||
latest_log = _get_latest_prompt_log_path()
|
||||
|
||||
# 如果没有找到,尝试递归搜索
|
||||
if not latest_log:
|
||||
all_logs = _get_any_logs_recursive()
|
||||
if not all_logs:
|
||||
if all_logs:
|
||||
latest_log = all_logs[-1] # 取最新的
|
||||
|
||||
if not latest_log:
|
||||
return (
|
||||
{
|
||||
"chunk_count": 0,
|
||||
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
|
||||
"未找到日志文件,请确认已运行过提取流程。",
|
||||
)
|
||||
|
||||
agg = {
|
||||
"chunk_count": 0,
|
||||
"statements_count": 0,
|
||||
"triplet_entities_count": 0,
|
||||
"triplet_relations_count": 0,
|
||||
"temporal_count": 0,
|
||||
}
|
||||
for path in all_logs:
|
||||
s = parse_stats_from_log(path)
|
||||
agg["chunk_count"] += s.get("chunk_count", 0)
|
||||
agg["statements_count"] += s.get("statements_count", 0)
|
||||
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
|
||||
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
|
||||
agg["temporal_count"] += s.get("temporal_count", 0)
|
||||
|
||||
# Attach a summary of files combined
|
||||
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
|
||||
return agg, "成功汇总 logs 目录中所有提示日志。"
|
||||
# 只解析最新的日志文件
|
||||
stats = parse_stats_from_log(latest_log)
|
||||
|
||||
# 添加日志文件路径信息
|
||||
stats["log_path"] = f"最新:{latest_log}"
|
||||
|
||||
return stats, "成功读取最近一次记忆活动统计。"
|
||||
|
||||
|
||||
def _format_summary(stats: dict) -> str:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
||||
@@ -1,30 +0,0 @@
|
||||
⏬数据集下载地址:
|
||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
||||
|
||||
全流程基准测试运行:
|
||||
locomo:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
||||
LongMemEval:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
||||
memsciqa:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
||||
|
||||
单独检索评估运行命令:
|
||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
||||
需要先在项目中修改需要检测评估的group_id。
|
||||
|
||||
参数及解释:
|
||||
● --dataset longmemeval - 指定数据集
|
||||
● --sample-size 10 - 评估10个样本
|
||||
● --start-index 0 - 从第0个样本开始
|
||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
||||
● --search-limit 8 - 检索限制8条
|
||||
● --context-char-budget 4000 - 上下文字符预算4000
|
||||
● --search-type hybrid - 使用混合检索
|
||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
||||
● --reset-group - 运行前清空组数据
|
||||
@@ -1,100 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = [t for t in text.split() if t]
|
||||
return tokens
|
||||
|
||||
|
||||
def exact_match(pred: str, ref: str) -> float:
|
||||
return float(_normalize(pred) == _normalize(ref))
|
||||
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
p = set(_normalize(pred))
|
||||
r = set(_normalize(ref))
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
|
||||
p_tokens = _normalize(pred)
|
||||
r_tokens = _normalize(ref)
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
# Clipped count
|
||||
r_counts: Dict[str, int] = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
clipped = 0
|
||||
p_counts: Dict[str, int] = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
# Brevity penalty
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
return bp * precision
|
||||
|
||||
|
||||
def percentile(values: List[float], p: float) -> float:
|
||||
if not values:
|
||||
return 0.0
|
||||
vals = sorted(values)
|
||||
k = (len(vals) - 1) * p
|
||||
f = math.floor(k)
|
||||
c = math.ceil(k)
|
||||
if f == c:
|
||||
return vals[int(k)]
|
||||
return vals[f] + (k - f) * (vals[c] - vals[f])
|
||||
|
||||
|
||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
|
||||
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
|
||||
if not latencies_ms:
|
||||
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
|
||||
p25 = percentile(latencies_ms, 0.25)
|
||||
p50 = percentile(latencies_ms, 0.50)
|
||||
p75 = percentile(latencies_ms, 0.75)
|
||||
p95 = percentile(latencies_ms, 0.95)
|
||||
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
|
||||
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
|
||||
|
||||
|
||||
def avg_context_tokens(contexts: List[str]) -> float:
|
||||
if not contexts:
|
||||
return 0.0
|
||||
return sum(len(_normalize(c)) for c in contexts) / len(contexts)
|
||||
@@ -1,60 +0,0 @@
|
||||
"""
|
||||
Dialogue search queries for evaluation purposes.
|
||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# Entity search queries
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name = $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.name CONTAINS $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
# Chunk search queries
|
||||
SEARCH_CHUNKS_BY_CONTENT = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.content CONTAINS $content
|
||||
RETURN c
|
||||
"""
|
||||
|
||||
# Dialogue search queries
|
||||
SEARCH_DIALOGUE_BY_DIALOG_ID = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_id = $dialog_id
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
SEARCH_DIALOGUES_BY_CONTENT = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.content CONTAINS $q
|
||||
RETURN d
|
||||
"""
|
||||
|
||||
DIALOGUE_EMBEDDING_SEARCH = """
|
||||
WITH $embedding AS q
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.dialog_embedding IS NOT NULL
|
||||
AND ($group_id IS NULL OR d.group_id = $group_id)
|
||||
WITH d, q, d.dialog_embedding AS v
|
||||
WITH d,
|
||||
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
|
||||
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
|
||||
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
|
||||
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
|
||||
WHERE score > $threshold
|
||||
RETURN d.id AS dialog_id,
|
||||
d.group_id AS group_id,
|
||||
d.content AS content,
|
||||
d.created_at AS created_at,
|
||||
d.expired_at AS expired_at,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -1,341 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
group_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
) -> bool:
|
||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
group_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
||||
|
||||
# Initialize llm client with graceful fallback
|
||||
llm_client = None
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
|
||||
# Step A: Build DialogData list from contexts with robust parsing
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
for idx, ctx in enumerate(contexts):
|
||||
messages: List[ConversationMessage] = []
|
||||
|
||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
||||
|
||||
if matches:
|
||||
for m in matches:
|
||||
raw_role = m.group(1).strip()
|
||||
content = m.group(2).strip()
|
||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
||||
else:
|
||||
# Fallback: line-by-line parsing
|
||||
for raw in ctx.split("\n"):
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
||||
else:
|
||||
# Final fallback: treat as user message
|
||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
||||
|
||||
context_model = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
group_id=group_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
# Generate chunks
|
||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("No dialogs to process for ingestion.")
|
||||
return False
|
||||
|
||||
# Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
||||
out_path = save_chunk_output_path or default_path
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to save chunking results: {e}")
|
||||
|
||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
||||
if not llm_available:
|
||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
||||
return False
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化并运行 ExtractionOrchestrator
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
||||
# 保存原始的 _assign_extracted_data 方法
|
||||
original_assign = orchestrator._assign_extracted_data
|
||||
|
||||
def clean_temporal_value(value):
|
||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
||||
return None
|
||||
return value
|
||||
|
||||
async def patched_assign_extracted_data(*args, **kwargs):
|
||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
||||
result = await original_assign(*args, **kwargs)
|
||||
|
||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
||||
for dialog in result:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
# 清理 valid_at 和 invalid_at
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
return result
|
||||
|
||||
# 替换方法
|
||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
||||
|
||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
||||
original_create = orchestrator._create_nodes_and_edges
|
||||
|
||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
||||
# 最后一次清理,确保万无一失
|
||||
for dialog in dialog_data_list_arg:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
|
||||
return await original_create(dialog_data_list_arg)
|
||||
|
||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run 返回 7 个元素的元组
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
|
||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
||||
|
||||
# Step G: 生成记忆摘要
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step H: Save to Neo4j
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
entity_edges=entity_entity_edges,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
# Save memory summaries separately
|
||||
if summaries:
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
if success:
|
||||
print("Successfully saved extracted data to Neo4j!")
|
||||
else:
|
||||
print("Failed to save data to Neo4j")
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"Failed to save data to Neo4j: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def handle_context_processing(args):
|
||||
"""Handle context-based processing from command line arguments."""
|
||||
contexts = []
|
||||
|
||||
if args.contexts:
|
||||
contexts.extend(args.contexts)
|
||||
|
||||
if args.context_file:
|
||||
try:
|
||||
with open(args.context_file, 'r', encoding='utf-8') as f:
|
||||
contexts.extend(line.strip() for line in f if line.strip())
|
||||
except Exception as e:
|
||||
print(f"Error reading context file: {e}")
|
||||
return False
|
||||
|
||||
if not contexts:
|
||||
print("No contexts provided for processing.")
|
||||
return False
|
||||
|
||||
return await main_from_contexts(contexts, args.context_group_id)
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], group_id: str):
|
||||
"""Run the pipeline from provided dialogue contexts instead of test data."""
|
||||
print("=== Running pipeline from provided contexts ===")
|
||||
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=contexts,
|
||||
group_id=group_id,
|
||||
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
|
||||
embedding_name=SELECTED_EMBEDDING_ID,
|
||||
save_chunk_output=True
|
||||
)
|
||||
|
||||
if success:
|
||||
print("Successfully processed and saved contexts to Neo4j!")
|
||||
else:
|
||||
print("Failed to process contexts.")
|
||||
|
||||
return success
|
||||
@@ -1,575 +0,0 @@
|
||||
"""
|
||||
LoCoMo Benchmark Script
|
||||
|
||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
||||
in a clean, maintainable way.
|
||||
|
||||
Usage:
|
||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
def load_dotenv():
|
||||
pass
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
get_category_name,
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
extract_conversations,
|
||||
ingest_conversations_if_needed,
|
||||
load_locomo_data,
|
||||
resolve_temporal_references,
|
||||
retrieve_relevant_information,
|
||||
select_and_format_information,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
|
||||
sample_size: int = 20,
|
||||
group_id: Optional[str] = None,
|
||||
search_type: str = "hybrid",
|
||||
search_limit: int = 12,
|
||||
context_char_budget: int = 8000,
|
||||
reset_group: bool = False,
|
||||
skip_ingest: bool = False,
|
||||
output_dir: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run LoCoMo benchmark evaluation.
|
||||
|
||||
This function orchestrates the complete evaluation pipeline:
|
||||
1. Load LoCoMo dataset (only QA pairs from first conversation)
|
||||
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
|
||||
3. For each question:
|
||||
- Retrieve relevant information
|
||||
- Generate answer using LLM
|
||||
- Calculate metrics
|
||||
4. Aggregate results and save to file
|
||||
|
||||
Note: By default, only the first conversation is ingested into the database,
|
||||
and only QA pairs from that conversation are evaluated. This ensures that
|
||||
all questions have corresponding memory in the database for retrieval.
|
||||
|
||||
Args:
|
||||
sample_size: Number of QA pairs to evaluate (from first conversation)
|
||||
group_id: Database group ID for retrieval (uses default if None)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max documents to retrieve per query
|
||||
context_char_budget: Max characters for context
|
||||
reset_group: Whether to clear and re-ingest data (not implemented)
|
||||
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
|
||||
output_dir: Directory to save results (uses default if None)
|
||||
|
||||
Returns:
|
||||
Dictionary with evaluation results including metrics, timing, and samples
|
||||
"""
|
||||
# Use default group_id if not provided
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
# Determine data path
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
# Fallback to current directory
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("🚀 Starting LoCoMo Benchmark Evaluation")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Configuration:")
|
||||
print(f" Sample size: {sample_size}")
|
||||
print(f" Group ID: {group_id}")
|
||||
print(f" Search type: {search_type}")
|
||||
print(f" Search limit: {search_limit}")
|
||||
print(f" Context budget: {context_char_budget} chars")
|
||||
print(f" Data path: {data_path}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Step 1: Load LoCoMo data
|
||||
print("📂 Loading LoCoMo dataset...")
|
||||
try:
|
||||
# Only load QA pairs from the first conversation (index 0)
|
||||
# since we only ingest the first conversation into the database
|
||||
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
|
||||
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to load data: {e}")
|
||||
return {
|
||||
"error": f"Data loading failed: {e}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Step 2: Extract conversations and ingest if needed
|
||||
if skip_ingest:
|
||||
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
|
||||
print(f" Group ID: {group_id}\n")
|
||||
else:
|
||||
print("💾 Checking database ingestion...")
|
||||
try:
|
||||
conversations = extract_conversations(data_path, max_dialogues=1)
|
||||
print(f"📝 Extracted {len(conversations)} conversations")
|
||||
|
||||
# Always ingest for now (ingestion check not implemented)
|
||||
print(f"🔄 Ingesting conversations into group '{group_id}'...")
|
||||
success = await ingest_conversations_if_needed(
|
||||
conversations=conversations,
|
||||
group_id=group_id,
|
||||
reset=reset_group
|
||||
)
|
||||
|
||||
if success:
|
||||
print("✅ Ingestion completed successfully\n")
|
||||
else:
|
||||
print("⚠️ Ingestion may have failed, continuing anyway\n")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Ingestion failed: {e}")
|
||||
print("⚠️ Continuing with evaluation (database may be empty)\n")
|
||||
|
||||
# Step 3: Initialize clients
|
||||
print("🔧 Initializing clients...")
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# Initialize LLM client with database context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Initialize embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
print("✅ Clients initialized\n")
|
||||
|
||||
# Step 4: Process questions
|
||||
print(f"🔍 Processing {len(qa_items)} questions...")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Tracking variables
|
||||
latencies_search: List[float] = []
|
||||
latencies_llm: List[float] = []
|
||||
context_counts: List[int] = []
|
||||
context_chars: List[int] = []
|
||||
context_tokens: List[int] = []
|
||||
|
||||
# Metric lists
|
||||
f1_scores: List[float] = []
|
||||
bleu1_scores: List[float] = []
|
||||
jaccard_scores: List[float] = []
|
||||
locomo_f1_scores: List[float] = []
|
||||
|
||||
# Per-category tracking
|
||||
category_counts: Dict[str, int] = {}
|
||||
category_f1: Dict[str, List[float]] = {}
|
||||
category_bleu1: Dict[str, List[float]] = {}
|
||||
category_jaccard: Dict[str, List[float]] = {}
|
||||
category_locomo_f1: Dict[str, List[float]] = {}
|
||||
|
||||
# Detailed samples
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
# Fixed anchor date for temporal resolution
|
||||
anchor_date = datetime(2023, 5, 8)
|
||||
|
||||
try:
|
||||
for idx, item in enumerate(qa_items, 1):
|
||||
question = item.get("question", "")
|
||||
ground_truth = item.get("answer", "")
|
||||
category = get_category_name(item)
|
||||
|
||||
# Ensure ground truth is a string
|
||||
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
print(f"[{idx}/{len(qa_items)}] Category: {category}")
|
||||
print(f"❓ Question: {question}")
|
||||
print(f"✅ Ground Truth: {ground_truth_str}")
|
||||
|
||||
# Step 4a: Retrieve relevant information
|
||||
t_search_start = time.time()
|
||||
try:
|
||||
retrieved_info = await retrieve_relevant_information(
|
||||
question=question,
|
||||
group_id=group_id,
|
||||
search_type=search_type,
|
||||
search_limit=search_limit,
|
||||
connector=connector,
|
||||
embedder=embedder
|
||||
)
|
||||
t_search_end = time.time()
|
||||
search_latency = (t_search_end - t_search_start) * 1000
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Retrieval failed: {e}")
|
||||
retrieved_info = []
|
||||
search_latency = 0.0
|
||||
latencies_search.append(search_latency)
|
||||
|
||||
# Step 4b: Select and format context
|
||||
context_text = select_and_format_information(
|
||||
retrieved_info=retrieved_info,
|
||||
question=question,
|
||||
max_chars=context_char_budget
|
||||
)
|
||||
|
||||
# Resolve temporal references
|
||||
context_text = resolve_temporal_references(context_text, anchor_date)
|
||||
|
||||
# Add reference date to context
|
||||
if context_text:
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
|
||||
else:
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# Track context statistics
|
||||
context_counts.append(len(retrieved_info))
|
||||
context_chars.append(len(context_text))
|
||||
context_tokens.append(len(context_text.split()))
|
||||
|
||||
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
|
||||
|
||||
# Step 4c: Generate answer with LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Question: {question}\n\nContext:\n{context_text}"
|
||||
}
|
||||
]
|
||||
|
||||
t_llm_start = time.time()
|
||||
try:
|
||||
response = await llm_client.chat(messages=messages)
|
||||
t_llm_end = time.time()
|
||||
llm_latency = (t_llm_end - t_llm_start) * 1000
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Extract prediction from response
|
||||
if hasattr(response, 'content'):
|
||||
prediction = response.content.strip()
|
||||
elif isinstance(response, dict):
|
||||
prediction = response["choices"][0]["message"]["content"].strip()
|
||||
else:
|
||||
prediction = "Unknown"
|
||||
|
||||
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ LLM failed: {e}")
|
||||
prediction = "Unknown"
|
||||
llm_latency = 0.0
|
||||
latencies_llm.append(llm_latency)
|
||||
|
||||
# Step 4d: Calculate metrics
|
||||
f1_val = f1_score(prediction, ground_truth_str)
|
||||
bleu1_val = bleu1(prediction, ground_truth_str)
|
||||
jaccard_val = jaccard(prediction, ground_truth_str)
|
||||
|
||||
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
|
||||
if item.get("category") == 1:
|
||||
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
|
||||
else:
|
||||
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
|
||||
|
||||
# Accumulate metrics
|
||||
f1_scores.append(f1_val)
|
||||
bleu1_scores.append(bleu1_val)
|
||||
jaccard_scores.append(jaccard_val)
|
||||
locomo_f1_scores.append(locomo_f1_val)
|
||||
|
||||
# Track by category
|
||||
category_counts[category] = category_counts.get(category, 0) + 1
|
||||
category_f1.setdefault(category, []).append(f1_val)
|
||||
category_bleu1.setdefault(category, []).append(bleu1_val)
|
||||
category_jaccard.setdefault(category, []).append(jaccard_val)
|
||||
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
|
||||
|
||||
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
|
||||
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
|
||||
print()
|
||||
|
||||
# Save sample details
|
||||
samples.append({
|
||||
"question": question,
|
||||
"ground_truth": ground_truth_str,
|
||||
"prediction": prediction,
|
||||
"category": category,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"bleu1": bleu1_val,
|
||||
"jaccard": jaccard_val,
|
||||
"locomo_f1": locomo_f1_val
|
||||
},
|
||||
"retrieval": {
|
||||
"num_docs": len(retrieved_info),
|
||||
"context_length": len(context_text)
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_latency,
|
||||
"llm_ms": llm_latency
|
||||
}
|
||||
})
|
||||
|
||||
finally:
|
||||
# Close connector
|
||||
await connector.close()
|
||||
|
||||
# Step 5: Aggregate results
|
||||
print(f"\n{'='*60}")
|
||||
print("📊 Aggregating Results")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Overall metrics
|
||||
overall_metrics = {
|
||||
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
|
||||
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
|
||||
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
|
||||
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
|
||||
}
|
||||
|
||||
# Per-category metrics
|
||||
by_category: Dict[str, Dict[str, Any]] = {}
|
||||
for cat in category_counts:
|
||||
f1_list = category_f1.get(cat, [])
|
||||
b1_list = category_bleu1.get(cat, [])
|
||||
j_list = category_jaccard.get(cat, [])
|
||||
lf_list = category_locomo_f1.get(cat, [])
|
||||
|
||||
by_category[cat] = {
|
||||
"count": category_counts[cat],
|
||||
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
|
||||
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
|
||||
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
|
||||
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
|
||||
}
|
||||
|
||||
# Latency statistics
|
||||
latency = {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm)
|
||||
}
|
||||
|
||||
# Context statistics
|
||||
context_stats = {
|
||||
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
|
||||
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
|
||||
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
|
||||
}
|
||||
|
||||
# Build result dictionary
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"sample_size": len(qa_items),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_type": search_type,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"overall_metrics": overall_metrics,
|
||||
"by_category": by_category,
|
||||
"latency": latency,
|
||||
"context_stats": context_stats,
|
||||
"samples": samples
|
||||
}
|
||||
|
||||
# Step 6: Save results
|
||||
if output_dir is None:
|
||||
output_dir = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
"results"
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Generate timestamped filename
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ Results saved to: {output_path}\n")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to save results: {e}")
|
||||
print("📊 Printing results to console instead:\n")
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Parse command-line arguments and run benchmark.
|
||||
|
||||
This function provides a CLI interface for running LoCoMo benchmarks
|
||||
with configurable parameters.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run LoCoMo benchmark evaluation",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample_size",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of QA pairs to evaluate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Database group ID for retrieval (uses default if not specified)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_type",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
choices=["keyword", "embedding", "hybrid"],
|
||||
help="Search strategy to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--search_limit",
|
||||
type=int,
|
||||
default=12,
|
||||
help="Maximum number of documents to retrieve per query"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--context_char_budget",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Maximum characters for context"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reset_group",
|
||||
action="store_true",
|
||||
help="Clear and re-ingest data (not implemented)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_ingest",
|
||||
action="store_true",
|
||||
help="Skip data ingestion and use existing data in Neo4j"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to save results (uses default if not specified)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Run benchmark
|
||||
result = asyncio.run(run_locomo_benchmark(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_type=args.search_type,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
reset_group=args.reset_group,
|
||||
skip_ingest=args.skip_ingest,
|
||||
output_dir=args.output_dir
|
||||
))
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'='*60}")
|
||||
|
||||
# Check if there was an error
|
||||
if 'error' in result:
|
||||
print("❌ Benchmark Failed!")
|
||||
print(f"{'='*60}")
|
||||
print(f"Error: {result['error']}")
|
||||
return
|
||||
|
||||
print("🎉 Benchmark Complete!")
|
||||
print(f"{'='*60}")
|
||||
print("📊 Final Results:")
|
||||
print(f" Sample size: {result.get('sample_size', 0)}")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
|
||||
|
||||
if result.get('context_stats'):
|
||||
print("\n📈 Context Statistics:")
|
||||
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
|
||||
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
|
||||
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
|
||||
|
||||
if result.get('latency'):
|
||||
print("\n⏱️ Latency Statistics:")
|
||||
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['search']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['search']['p95']:.1f}ms")
|
||||
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
|
||||
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
|
||||
f"P95: {result['latency']['llm']['p95']:.1f}ms")
|
||||
|
||||
if result.get('by_category'):
|
||||
print("\n📂 Results by Category:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" Count: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['jaccard']:.3f}")
|
||||
|
||||
print(f"\n{'='*60}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
LoCoMo-specific metric calculations.
|
||||
|
||||
This module provides clean, simplified implementations of metrics used for
|
||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
|
||||
"""
|
||||
Normalize text for LoCoMo evaluation.
|
||||
|
||||
Normalization steps:
|
||||
- Convert to lowercase
|
||||
- Remove commas
|
||||
- Remove stop words (a, an, the, and)
|
||||
- Remove punctuation
|
||||
- Normalize whitespace
|
||||
|
||||
Args:
|
||||
text: Input text to normalize
|
||||
|
||||
Returns:
|
||||
Normalized text string with consistent formatting
|
||||
|
||||
Examples:
|
||||
>>> normalize_text("The cat, and the dog")
|
||||
'cat dog'
|
||||
>>> normalize_text("Hello, World!")
|
||||
'hello world'
|
||||
"""
|
||||
# Ensure input is a string
|
||||
text = str(text) if text is not None else ""
|
||||
|
||||
# Convert to lowercase
|
||||
text = text.lower()
|
||||
|
||||
# Remove commas
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
|
||||
# Remove stop words
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
|
||||
# Remove punctuation (keep only word characters and whitespace)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
|
||||
# Normalize whitespace (collapse multiple spaces to single space)
|
||||
text = " ".join(text.split())
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for single-answer questions.
|
||||
|
||||
Uses token-level precision and recall based on normalized text.
|
||||
Treats tokens as sets (no duplicate counting).
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer
|
||||
ground_truth: Correct answer
|
||||
|
||||
Returns:
|
||||
F1 score between 0.0 and 1.0
|
||||
|
||||
Examples:
|
||||
>>> locomo_f1_score("Paris", "Paris")
|
||||
1.0
|
||||
>>> locomo_f1_score("The cat", "cat")
|
||||
1.0
|
||||
>>> locomo_f1_score("dog", "cat")
|
||||
0.0
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Normalize and tokenize
|
||||
pred_tokens = normalize_text(pred_str).split()
|
||||
truth_tokens = normalize_text(truth_str).split()
|
||||
|
||||
# Handle empty cases
|
||||
if not pred_tokens or not truth_tokens:
|
||||
return 0.0
|
||||
|
||||
# Convert to sets for comparison
|
||||
pred_set = set(pred_tokens)
|
||||
truth_set = set(truth_tokens)
|
||||
|
||||
# Calculate true positives (intersection)
|
||||
true_positives = len(pred_set & truth_set)
|
||||
|
||||
# Calculate precision and recall
|
||||
precision = true_positives / len(pred_set) if pred_set else 0.0
|
||||
recall = true_positives / len(truth_set) if truth_set else 0.0
|
||||
|
||||
# Calculate F1 score
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
|
||||
f1 = 2 * precision * recall / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
"""
|
||||
Calculate LoCoMo F1 score for multi-answer questions.
|
||||
|
||||
Handles comma-separated answers by:
|
||||
1. Splitting both prediction and ground truth by commas
|
||||
2. For each ground truth answer, finding the best matching prediction
|
||||
3. Averaging the F1 scores across all ground truth answers
|
||||
|
||||
Args:
|
||||
prediction: Model's predicted answer (may contain multiple comma-separated answers)
|
||||
ground_truth: Correct answer (may contain multiple comma-separated answers)
|
||||
|
||||
Returns:
|
||||
Average F1 score across all ground truth answers (0.0 to 1.0)
|
||||
|
||||
Examples:
|
||||
>>> locomo_multi_f1("Paris, London", "Paris, London")
|
||||
1.0
|
||||
>>> locomo_multi_f1("Paris", "Paris, London")
|
||||
0.5
|
||||
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
|
||||
0.5
|
||||
"""
|
||||
# Ensure inputs are strings
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
# Split by commas and strip whitespace
|
||||
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
|
||||
|
||||
# Handle empty cases
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
|
||||
# For each ground truth, find the best matching prediction
|
||||
f1_scores = []
|
||||
for gt in ground_truths:
|
||||
# Calculate F1 with each prediction and take the maximum
|
||||
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
|
||||
f1_scores.append(best_f1)
|
||||
|
||||
# Return average F1 across all ground truths
|
||||
return sum(f1_scores) / len(f1_scores)
|
||||
|
||||
|
||||
def get_category_name(item: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract and normalize category name from QA item.
|
||||
|
||||
Handles both numeric categories (1-4) and string categories with various formats.
|
||||
Supports multiple field names: "cat", "category", "type".
|
||||
|
||||
Category mapping:
|
||||
- 1 or "multi-hop" -> "Multi-Hop"
|
||||
- 2 or "temporal" -> "Temporal"
|
||||
- 3 or "open domain" -> "Open Domain"
|
||||
- 4 or "single-hop" -> "Single-Hop"
|
||||
|
||||
Args:
|
||||
item: QA item dictionary containing category information
|
||||
|
||||
Returns:
|
||||
Standardized category name or "unknown" if not found
|
||||
|
||||
Examples:
|
||||
>>> get_category_name({"category": 1})
|
||||
'Multi-Hop'
|
||||
>>> get_category_name({"cat": "temporal"})
|
||||
'Temporal'
|
||||
>>> get_category_name({"type": "Single-Hop"})
|
||||
'Single-Hop'
|
||||
"""
|
||||
# Numeric category mapping
|
||||
CATEGORY_MAP = {
|
||||
1: "Multi-Hop",
|
||||
2: "Temporal",
|
||||
3: "Open Domain",
|
||||
4: "Single-Hop",
|
||||
}
|
||||
|
||||
# String category aliases (case-insensitive)
|
||||
TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
# Try "cat" field first (string category)
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return TYPE_ALIASES.get(lower, name)
|
||||
|
||||
# Try "category" field (can be int or string)
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP.get(cat_num, "unknown")
|
||||
elif isinstance(cat_num, str) and cat_num.strip():
|
||||
lower = cat_num.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_num.strip())
|
||||
|
||||
# Try "type" field as fallback
|
||||
cat_type = item.get("type")
|
||||
if isinstance(cat_type, str) and cat_type.strip():
|
||||
lower = cat_type.strip().lower()
|
||||
return TYPE_ALIASES.get(lower, cat_type.strip())
|
||||
|
||||
return "unknown"
|
||||
@@ -1,810 +0,0 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(current_dir)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
||||
def _loc_normalize(text: str) -> str:
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text)
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
# 回退到本地实现
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens and not r_tokens:
|
||||
return 1.0
|
||||
if not p_tokens or not r_tokens:
|
||||
return 0.0
|
||||
p_set = set(p_tokens)
|
||||
r_set = set(r_tokens)
|
||||
tp = len(p_set & r_set)
|
||||
precision = tp / len(p_set) if p_set else 0.0
|
||||
recall = tp / len(r_set) if r_set else 0.0
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
r_tokens = _loc_normalize(ref_str).split()
|
||||
if not p_tokens:
|
||||
return 0.0
|
||||
|
||||
r_counts = {}
|
||||
for t in r_tokens:
|
||||
r_counts[t] = r_counts.get(t, 0) + 1
|
||||
|
||||
clipped = 0
|
||||
p_counts = {}
|
||||
for t in p_tokens:
|
||||
p_counts[t] = p_counts.get(t, 0) + 1
|
||||
|
||||
for t, c in p_counts.items():
|
||||
clipped += min(c, r_counts.get(t, 0))
|
||||
|
||||
precision = clipped / max(len(p_tokens), 1)
|
||||
ref_len = len(r_tokens)
|
||||
pred_len = len(p_tokens)
|
||||
|
||||
if pred_len > ref_len or pred_len == 0:
|
||||
bp = 1.0
|
||||
else:
|
||||
bp = math.exp(1 - ref_len / max(pred_len, 1))
|
||||
|
||||
return bp * precision
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
pred_str = str(pred) if pred is not None else ""
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
p = set(_loc_normalize(pred_str).split())
|
||||
r = set(_loc_normalize(ref_str).split())
|
||||
if not p and not r:
|
||||
return 1.0
|
||||
if not p or not r:
|
||||
return 0.0
|
||||
return len(p & r) / len(p | r)
|
||||
|
||||
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
|
||||
try:
|
||||
# 添加 evaluation 目录路径
|
||||
evaluation_dir = os.path.join(project_root, "evaluation")
|
||||
if evaluation_dir not in sys.path:
|
||||
sys.path.insert(0, evaluation_dir)
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
|
||||
# 回退到本地实现 LoCoMo 特定函数
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
t = str(text) if text is not None else ""
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
p_tokens = _loc_normalize(prediction).split()
|
||||
g_tokens = _loc_normalize(ground_truth).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
|
||||
"""根据问题复杂度和进度动态调整检索参数"""
|
||||
|
||||
# 分析问题复杂度
|
||||
word_count = len(question.split())
|
||||
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
|
||||
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
|
||||
|
||||
# 根据进度调整 - 后期问题可能需要更精确的检索
|
||||
progress_factor = question_index / total_questions
|
||||
|
||||
base_limit = 12
|
||||
if has_temporal and has_multi_hop:
|
||||
base_limit = 20
|
||||
elif word_count > 8:
|
||||
base_limit = 16
|
||||
|
||||
# 随着测试进行,逐渐收紧检索范围
|
||||
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
|
||||
|
||||
# 动态调整最大字符数
|
||||
max_chars = 8000 + 4000 * (1 - progress_factor)
|
||||
|
||||
return {
|
||||
"limit": adjusted_limit,
|
||||
"max_chars": int(max_chars)
|
||||
}
|
||||
|
||||
|
||||
class EnhancedEvaluationMonitor:
|
||||
def __init__(self, reset_interval=5, performance_threshold=0.6):
|
||||
self.question_count = 0
|
||||
self.reset_interval = reset_interval
|
||||
self.performance_threshold = performance_threshold
|
||||
self.consecutive_low_scores = 0
|
||||
self.performance_history = []
|
||||
self.recent_f1_scores = []
|
||||
|
||||
def should_reset_connections(self, current_f1=None):
|
||||
"""基于计数和性能双重判断"""
|
||||
# 定期重置
|
||||
if self.question_count % self.reset_interval == 0:
|
||||
return True
|
||||
|
||||
# 性能驱动的重置
|
||||
if current_f1 is not None and current_f1 < self.performance_threshold:
|
||||
self.consecutive_low_scores += 1
|
||||
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
|
||||
print("🚨 连续低分,触发紧急重置")
|
||||
self.consecutive_low_scores = 0
|
||||
return True
|
||||
else:
|
||||
self.consecutive_low_scores = 0
|
||||
|
||||
return False
|
||||
|
||||
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
|
||||
"""记录性能指标,检测衰减"""
|
||||
self.performance_history.append({
|
||||
'index': question_index,
|
||||
'metrics': metrics,
|
||||
'context_length': context_length,
|
||||
'retrieved_docs': retrieved_docs,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
# 记录最近的F1分数
|
||||
self.recent_f1_scores.append(metrics['f1'])
|
||||
if len(self.recent_f1_scores) > 5:
|
||||
self.recent_f1_scores.pop(0)
|
||||
|
||||
def get_recent_performance(self):
|
||||
"""获取近期平均性能"""
|
||||
if not self.recent_f1_scores:
|
||||
return 0.5
|
||||
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
|
||||
|
||||
def get_performance_trend(self):
|
||||
"""分析性能趋势"""
|
||||
if len(self.performance_history) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
|
||||
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
|
||||
|
||||
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
|
||||
return "stable"
|
||||
|
||||
recent_avg = sum(recent_metrics) / len(recent_metrics)
|
||||
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
|
||||
|
||||
if recent_avg < earlier_avg * 0.8:
|
||||
return "degrading"
|
||||
elif recent_avg > earlier_avg * 1.1:
|
||||
return "improving"
|
||||
else:
|
||||
return "stable"
|
||||
|
||||
|
||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
|
||||
"""基于问题复杂度和近期性能动态调整检索参数"""
|
||||
|
||||
# 基础参数
|
||||
base_params = get_dynamic_search_params(question, question_index, total_questions)
|
||||
|
||||
# 性能自适应调整
|
||||
if recent_performance < 0.5: # 近期表现差
|
||||
# 增加检索范围,尝试获取更多上下文
|
||||
base_params["limit"] = min(base_params["limit"] + 5, 25)
|
||||
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
|
||||
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
elif recent_performance > 0.8: # 近期表现好
|
||||
# 收紧检索,提高精度
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 8)
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
|
||||
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
|
||||
|
||||
# 中间阶段特殊处理
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5)
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用更精确的检索策略")
|
||||
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
|
||||
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
|
||||
|
||||
return base_params
|
||||
|
||||
|
||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
|
||||
"""考虑问题序列位置的智能选择"""
|
||||
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 在序列中间阶段使用更严格的筛选
|
||||
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
|
||||
|
||||
if mid_sequence_factor < 0.2: # 在中间30%的问题
|
||||
print("🎯 中间阶段:使用严格上下文筛选")
|
||||
|
||||
# 提取问题关键词
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
# 只保留高度相关的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts:
|
||||
context_lower = context.lower()
|
||||
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
|
||||
|
||||
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
|
||||
if any(char.isdigit() for char in context):
|
||||
relevance_score += 2
|
||||
|
||||
# 提高阈值:只有得分>=3的上下文才保留
|
||||
if relevance_score >= 3:
|
||||
filtered_contexts.append(context)
|
||||
else:
|
||||
print(f" - 过滤低分上下文: 得分={relevance_score}")
|
||||
|
||||
contexts = filtered_contexts
|
||||
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
|
||||
|
||||
# 使用原有的智能选择逻辑
|
||||
return smart_context_selection(contexts, question, max_chars)
|
||||
|
||||
|
||||
async def run_enhanced_evaluation():
|
||||
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
current_file = os.path.abspath(__file__)
|
||||
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
|
||||
memory_dir = os.path.dirname(evaluation_dir) # memory目录
|
||||
data_path = os.path.join(memory_dir, "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
qa_items = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
|
||||
items = qa_items[:20] # 测试多少个问题
|
||||
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 初始化连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化结果字典
|
||||
results = {
|
||||
"questions": [],
|
||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
||||
"category_metrics": {},
|
||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
||||
"performance_trend": "stable",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"enhanced_strategy": True
|
||||
}
|
||||
|
||||
total_f1 = 0.0
|
||||
total_bleu1 = 0.0
|
||||
total_jaccard = 0.0
|
||||
total_loc_f1 = 0.0
|
||||
total_context_length = 0
|
||||
total_retrieved_docs = 0
|
||||
category_stats = {}
|
||||
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
monitor.question_count += 1
|
||||
|
||||
# 获取近期性能用于重置判断
|
||||
recent_performance = monitor.get_recent_performance()
|
||||
|
||||
# 增强的重置判断
|
||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
||||
if should_reset and i > 0:
|
||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
||||
await connector.close()
|
||||
connector = Neo4jConnector() # 创建新连接
|
||||
print("✅ 连接重置完成")
|
||||
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
|
||||
print(f"✅ 真实答案: {ref_str}")
|
||||
|
||||
# 分类别统计
|
||||
category = "Unknown"
|
||||
if item.get("category") == 1:
|
||||
category = "Multi-Hop"
|
||||
elif item.get("category") == 2:
|
||||
category = "Temporal"
|
||||
elif item.get("category") == 3:
|
||||
category = "Open Domain"
|
||||
elif item.get("category") == 4:
|
||||
category = "Single-Hop"
|
||||
|
||||
# 增强的检索参数
|
||||
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
|
||||
search_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
|
||||
|
||||
# 使用项目标准的混合检索方法
|
||||
t0 = time.time()
|
||||
contexts_all = []
|
||||
|
||||
try:
|
||||
# 使用统一的搜索服务
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
print("🔀 使用混合搜索服务...")
|
||||
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type="hybrid",
|
||||
group_id="locomo_sk",
|
||||
limit=20,
|
||||
include=["statements", "chunks", "entities", "summaries"],
|
||||
alpha=0.6, # BM25权重
|
||||
embedding_id=SELECTED_EMBEDDING_ID
|
||||
)
|
||||
|
||||
# 处理搜索结果 - 新的搜索服务返回统一的结构
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
except Exception as e:
|
||||
print(f"❌ 检索失败: {e}")
|
||||
contexts_all = []
|
||||
|
||||
t1 = time.time()
|
||||
search_time = (t1 - t0) * 1000
|
||||
|
||||
# 增强的上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
# 使用增强的上下文选择
|
||||
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览(不只是第一条)
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
# 🔍 调试:检查答案是否在上下文中
|
||||
if ref_str and ref_str.strip():
|
||||
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
|
||||
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# LLM 回答
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
except Exception as e:
|
||||
print(f"❌ LLM 生成失败: {e}")
|
||||
pred = "Unknown"
|
||||
t3 = time.time()
|
||||
llm_time = (t3 - t2) * 1000
|
||||
|
||||
# 计算指标 - 使用导入的指标函数
|
||||
f1_val = f1_score(pred, ref_str)
|
||||
bleu1_val = bleu1(pred, ref_str)
|
||||
jaccard_val = jaccard(pred, ref_str)
|
||||
loc_f1_val = loc_f1_score(pred, ref_str)
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
|
||||
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
|
||||
|
||||
# 更新统计
|
||||
total_f1 += f1_val
|
||||
total_bleu1 += bleu1_val
|
||||
total_jaccard += jaccard_val
|
||||
total_loc_f1 += loc_f1_val
|
||||
total_context_length += len(context_text)
|
||||
total_retrieved_docs += len(contexts_all)
|
||||
|
||||
if category not in category_stats:
|
||||
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
|
||||
|
||||
category_stats[category]["count"] += 1
|
||||
category_stats[category]["f1_sum"] += f1_val
|
||||
category_stats[category]["b1_sum"] += bleu1_val
|
||||
category_stats[category]["j_sum"] += jaccard_val
|
||||
category_stats[category]["loc_f1_sum"] += loc_f1_val
|
||||
|
||||
# 记录性能指标
|
||||
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
|
||||
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
|
||||
|
||||
# 保存结果
|
||||
question_result = {
|
||||
"question": q,
|
||||
"ground_truth": ref_str,
|
||||
"prediction": pred,
|
||||
"category": category,
|
||||
"metrics": metrics,
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": max_chars,
|
||||
"recent_performance": recent_performance
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_time,
|
||||
"llm_ms": llm_time
|
||||
}
|
||||
}
|
||||
|
||||
results["questions"].append(question_result)
|
||||
|
||||
print("="*60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 评估过程中发生错误: {e}")
|
||||
# 即使出错,也返回已有的结果
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
# 计算总体指标
|
||||
n = len(items)
|
||||
if n > 0:
|
||||
results["overall_metrics"] = {
|
||||
"f1": total_f1 / n,
|
||||
"b1": total_bleu1 / n,
|
||||
"j": total_jaccard / n,
|
||||
"loc_f1": total_loc_f1 / n
|
||||
}
|
||||
|
||||
for category, stats in category_stats.items():
|
||||
count = stats["count"]
|
||||
results["category_metrics"][category] = {
|
||||
"count": count,
|
||||
"f1": stats["f1_sum"] / count,
|
||||
"bleu1": stats["b1_sum"] / count,
|
||||
"jaccard": stats["j_sum"] / count,
|
||||
"loc_f1": stats["loc_f1_sum"] / count
|
||||
}
|
||||
|
||||
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
|
||||
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
|
||||
|
||||
# 分析性能趋势
|
||||
results["performance_trend"] = monitor.get_performance_trend()
|
||||
results["reset_interval"] = monitor.reset_interval
|
||||
results["total_questions_processed"] = monitor.question_count
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
|
||||
print("📋 增强特性:")
|
||||
print(" - 双重重置策略:定期重置 + 性能驱动重置")
|
||||
print(" - 动态检索参数:基于近期性能自适应调整")
|
||||
print(" - 中间阶段严格筛选:提高上下文质量要求")
|
||||
print(" - 连续性能监控:实时检测性能衰减")
|
||||
|
||||
result = asyncio.run(run_enhanced_evaluation())
|
||||
|
||||
print("\n📊 最终评估结果:")
|
||||
print("总体指标:")
|
||||
print(f" F1: {result['overall_metrics']['f1']:.4f}")
|
||||
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
|
||||
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
|
||||
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
|
||||
|
||||
print("\n分类别指标:")
|
||||
for category, metrics in result['category_metrics'].items():
|
||||
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
|
||||
|
||||
print("\n检索统计:")
|
||||
stats = result['retrieval_stats']
|
||||
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
|
||||
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
|
||||
|
||||
print(f"\n性能趋势: {result['performance_trend']}")
|
||||
print(f"重置间隔: 每{result['reset_interval']}个问题")
|
||||
print(f"处理问题总数: {result['total_questions_processed']}")
|
||||
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
|
||||
|
||||
|
||||
# 保存结果到指定目录
|
||||
# 使用代码文件所在目录的绝对路径
|
||||
current_file_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
output_dir = os.path.join(current_file_dir, "results")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n详细结果已保存到: {output_file}")
|
||||
@@ -1,626 +0,0 @@
|
||||
"""
|
||||
LoCoMo Utilities Module
|
||||
|
||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
||||
- Data loading from JSON files
|
||||
- Conversation extraction for ingestion
|
||||
- Temporal reference resolution
|
||||
- Context selection and formatting
|
||||
- Retrieval wrapper functions
|
||||
- Ingestion wrapper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
def load_locomo_data(
|
||||
data_path: str,
|
||||
sample_size: int,
|
||||
conversation_index: int = 0
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Load LoCoMo dataset from JSON file.
|
||||
|
||||
The LoCoMo dataset structure is a list of conversation objects, where each
|
||||
object contains a "qa" list of question-answer pairs.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
sample_size: Number of QA pairs to load (limits total QA items returned)
|
||||
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
|
||||
|
||||
Returns:
|
||||
List of QA item dictionaries, each containing:
|
||||
- question: str
|
||||
- answer: str
|
||||
- category: int (1-4)
|
||||
- evidence: List[str]
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If data_path does not exist
|
||||
json.JSONDecodeError: If file is not valid JSON
|
||||
IndexError: If conversation_index is out of range
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# LoCoMo data structure: list of objects, each with a "qa" list
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
|
||||
if isinstance(raw, list):
|
||||
# Only load QA pairs from the specified conversation
|
||||
if conversation_index < len(raw):
|
||||
entry = raw[conversation_index]
|
||||
if isinstance(entry, dict) and "qa" in entry:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has {len(raw)} conversations."
|
||||
)
|
||||
else:
|
||||
# Fallback: single object with qa list
|
||||
if conversation_index == 0:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
else:
|
||||
raise IndexError(
|
||||
f"Conversation index {conversation_index} out of range. "
|
||||
f"Dataset has only 1 conversation."
|
||||
)
|
||||
|
||||
# Return only the requested sample size
|
||||
return qa_items[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
"""
|
||||
Extract conversation texts from LoCoMo data for ingestion.
|
||||
|
||||
This function extracts the raw conversation dialogues from the LoCoMo dataset
|
||||
so they can be ingested into the memory system. Each conversation is formatted
|
||||
as a multi-line string with "role: message" format.
|
||||
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
max_dialogues: Maximum number of dialogues to extract (default: 1)
|
||||
|
||||
Returns:
|
||||
List of conversation strings formatted for ingestion.
|
||||
Each string contains multiple lines in format "role: message"
|
||||
|
||||
Example output:
|
||||
[
|
||||
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
|
||||
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
|
||||
]
|
||||
"""
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
# Ensure we have a list of entries
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
contents: List[str] = []
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
|
||||
if not isinstance(conv, dict):
|
||||
continue
|
||||
|
||||
lines: List[str] = []
|
||||
|
||||
# Collect all session_* messages
|
||||
for key, val in sorted(conv.items()):
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
for msg in val:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
role = msg.get("speaker") or "User"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
|
||||
if not text:
|
||||
continue
|
||||
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
if lines:
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
||||
"""
|
||||
Resolve relative temporal references to absolute dates.
|
||||
|
||||
This function converts relative time expressions (like "today", "yesterday",
|
||||
"3 days ago") into absolute ISO date strings based on an anchor date.
|
||||
|
||||
Supported patterns:
|
||||
- today, yesterday, tomorrow
|
||||
- X days ago, in X days
|
||||
- last week, next week
|
||||
|
||||
Args:
|
||||
text: Text containing temporal references
|
||||
anchor_date: Reference date for resolution (datetime object)
|
||||
|
||||
Returns:
|
||||
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
|
||||
|
||||
Example:
|
||||
>>> anchor = datetime(2023, 5, 8)
|
||||
>>> resolve_temporal_references("I saw him yesterday", anchor)
|
||||
"I saw him 2023-05-07"
|
||||
"""
|
||||
# Ensure input is a string
|
||||
t = str(text) if text is not None else ""
|
||||
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(
|
||||
r"\btoday\b",
|
||||
anchor_date.date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\byesterday\b",
|
||||
(anchor_date - timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\btomorrow\b",
|
||||
(anchor_date + timedelta(days=1)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# X days ago
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date - timedelta(days=n)).date().isoformat()
|
||||
|
||||
# in X days
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor_date + timedelta(days=n)).date().isoformat()
|
||||
|
||||
t = re.sub(
|
||||
r"\b(\d+)\s+days?\s+ago\b",
|
||||
_ago_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bin\s+(\d+)\s+days?\b",
|
||||
_in_repl,
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# last week / next week (approximate as 7 days)
|
||||
t = re.sub(
|
||||
r"\blast\s+week\b",
|
||||
(anchor_date - timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
t = re.sub(
|
||||
r"\bnext\s+week\b",
|
||||
(anchor_date + timedelta(days=7)).date().isoformat(),
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
return t
|
||||
|
||||
|
||||
def select_and_format_information(
|
||||
retrieved_info: List[str],
|
||||
question: str,
|
||||
max_chars: int = 8000
|
||||
) -> str:
|
||||
"""
|
||||
Intelligently select and format most relevant retrieved information for LLM prompt.
|
||||
|
||||
This function scores each piece of retrieved information based on keyword matching
|
||||
with the question, then selects the highest-scoring pieces up to the character limit.
|
||||
|
||||
Scoring criteria:
|
||||
- Keyword matches (higher weight for multiple occurrences)
|
||||
- Context length (moderate length preferred)
|
||||
- Position (earlier contexts get bonus points)
|
||||
|
||||
Args:
|
||||
retrieved_info: List of retrieved information strings (chunks, statements, entities)
|
||||
question: Question being answered
|
||||
max_chars: Maximum total characters to include in final prompt
|
||||
|
||||
Returns:
|
||||
Formatted string combining the most relevant information for LLM prompt.
|
||||
Contexts are separated by double newlines.
|
||||
|
||||
Example:
|
||||
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
|
||||
>>> question = "Where did Alice go?"
|
||||
>>> select_and_format_information(contexts, question, max_chars=100)
|
||||
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
|
||||
"""
|
||||
if not retrieved_info:
|
||||
return ""
|
||||
|
||||
# Extract question keywords (filter out stop words and short words)
|
||||
question_lower = question.lower()
|
||||
stop_words = {
|
||||
'what', 'when', 'where', 'who', 'why', 'how',
|
||||
'did', 'do', 'does', 'is', 'are', 'was', 'were',
|
||||
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
|
||||
}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {
|
||||
word for word in question_words
|
||||
if word not in stop_words and len(word) > 2
|
||||
}
|
||||
|
||||
# Score each context
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(retrieved_info):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# Keyword matching score
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# Multiple occurrences increase score
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# Length score (prefer moderate length)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000:
|
||||
score += 5
|
||||
elif context_len >= 2000:
|
||||
score += 2
|
||||
|
||||
# Position bonus (earlier contexts often more relevant)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# Sort by score (descending)
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# Select contexts up to character limit
|
||||
selected = []
|
||||
total_chars = 0
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
else:
|
||||
# Try to include high-scoring context by truncating
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# Find lines with keywords
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
selected.append(truncated + "\n[Content truncated...]")
|
||||
total_chars += len(truncated)
|
||||
break
|
||||
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
async def retrieve_relevant_information(
|
||||
question: str,
|
||||
group_id: str,
|
||||
search_type: str,
|
||||
search_limit: int,
|
||||
connector: Any,
|
||||
embedder: Any
|
||||
) -> List[str]:
|
||||
"""
|
||||
Retrieve relevant information from memory graph for a question.
|
||||
|
||||
This function searches the Neo4j memory graph (populated during ingestion) and
|
||||
returns relevant chunks, statements, and entity information that might help
|
||||
answer the question.
|
||||
|
||||
The function supports three search types:
|
||||
- "keyword": Full-text search using Cypher queries
|
||||
- "embedding": Vector similarity search using embeddings
|
||||
- "hybrid": Combination of keyword and embedding search with reranking
|
||||
|
||||
Args:
|
||||
question: Question to search for
|
||||
group_id: Database group ID (identifies which conversation memory to search)
|
||||
search_type: "keyword", "embedding", or "hybrid"
|
||||
search_limit: Max memory pieces to retrieve
|
||||
connector: Neo4j connector instance
|
||||
embedder: Embedder client instance
|
||||
|
||||
Returns:
|
||||
List of text strings (chunks, statements, entity summaries) from memory graph.
|
||||
Each string represents a piece of retrieved information.
|
||||
|
||||
Raises:
|
||||
Exception: If search fails (caught and returns empty list)
|
||||
"""
|
||||
from app.repositories.neo4j.graph_search import (
|
||||
search_graph,
|
||||
search_graph_by_embedding
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
|
||||
contexts_all: List[str] = []
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# Embedding-based search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context from chunks
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add summaries
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities (limit to 3 to avoid noise)
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# Keyword-based search
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit
|
||||
)
|
||||
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
|
||||
# Build context from dialogues
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
# Add statements
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# Add entity names
|
||||
if entities:
|
||||
entity_names = [
|
||||
str(e.get("name", "")).strip()
|
||||
for e in entities[:5]
|
||||
if e.get("name")
|
||||
]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# Hybrid search with fallback to embedding
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# Handle flat structure (new API format)
|
||||
if search_results and isinstance(search_results, dict):
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Check if we got results
|
||||
if not (chunks or statements or entities or summaries):
|
||||
# Try nested structure (backward compatibility)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
else:
|
||||
raise ValueError("Hybrid search returned empty results")
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to embedding search
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# Build context (same for both hybrid and fallback)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# Add top entities
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = (
|
||||
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
|
||||
if scored else entities[:3]
|
||||
)
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(
|
||||
f"EntitySummary: {name}"
|
||||
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
|
||||
)
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
except Exception as e:
|
||||
# Return empty list on error
|
||||
contexts_all = []
|
||||
|
||||
return contexts_all
|
||||
|
||||
|
||||
async def ingest_conversations_if_needed(
|
||||
conversations: List[str],
|
||||
group_id: str,
|
||||
reset: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Wrapper for conversation ingestion using external extraction pipeline.
|
||||
|
||||
This function populates the Neo4j database with processed conversation data
|
||||
(chunks, statements, entities) so that the retrieval system has memory to search.
|
||||
|
||||
The ingestion process:
|
||||
1. Parses conversation text into dialogue messages
|
||||
2. Chunks the dialogues into semantic units
|
||||
3. Extracts statements and entities using LLM
|
||||
4. Generates embeddings for all content
|
||||
5. Stores everything in Neo4j graph database
|
||||
|
||||
Args:
|
||||
conversations: List of raw conversation texts from LoCoMo dataset
|
||||
Example: ["User: I went to Paris. AI: When was that?", ...]
|
||||
group_id: Target group ID for database storage
|
||||
reset: Whether to clear existing data first (not implemented in wrapper)
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
|
||||
Note:
|
||||
The external function uses "contexts" to mean "conversation texts".
|
||||
This runs the full extraction pipeline: chunking → entity extraction →
|
||||
statement extraction → embedding → Neo4j storage.
|
||||
"""
|
||||
try:
|
||||
success = await ingest_contexts_via_full_pipeline(
|
||||
contexts=conversations,
|
||||
group_id=group_id,
|
||||
save_chunk_output=True
|
||||
)
|
||||
return success
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to ingest conversations: {e}")
|
||||
return False
|
||||
@@ -1,878 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
import re
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
def _loc_normalize(text: str) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text) # 去掉逗号
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week)
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
t = str(text) if text is not None else ""
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
# X days ago / in X days
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
# last week / next week(以7天近似)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
|
||||
# 单答案 F1:按词集合计算(近似原始实现,去除词干依赖)
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
p_tokens = _loc_normalize(pred_str).split()
|
||||
g_tokens = _loc_normalize(truth_str).split()
|
||||
if not p_tokens or not g_tokens:
|
||||
return 0.0
|
||||
p = set(p_tokens)
|
||||
g = set(g_tokens)
|
||||
tp = len(p & g)
|
||||
precision = tp / len(p) if p else 0.0
|
||||
recall = tp / len(g) if g else 0.0
|
||||
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
|
||||
# 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
|
||||
# 确保输入是字符串
|
||||
pred_str = str(prediction) if prediction is not None else ""
|
||||
truth_str = str(ground_truth) if ground_truth is not None else ""
|
||||
|
||||
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
|
||||
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
|
||||
if not predictions or not ground_truths:
|
||||
return 0.0
|
||||
def _f1(a: str, b: str) -> float:
|
||||
return loc_f1_score(a, b)
|
||||
vals = []
|
||||
for gt in ground_truths:
|
||||
vals.append(max(_f1(pred, gt) for pred in predictions))
|
||||
return sum(vals) / len(vals)
|
||||
|
||||
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
|
||||
CATEGORY_MAP_NUM_TO_NAME = {
|
||||
4: "Single-Hop",
|
||||
1: "Multi-Hop",
|
||||
3: "Open Domain",
|
||||
2: "Temporal",
|
||||
}
|
||||
|
||||
_TYPE_ALIASES = {
|
||||
"single-hop": "Single-Hop",
|
||||
"singlehop": "Single-Hop",
|
||||
"single hop": "Single-Hop",
|
||||
"multi-hop": "Multi-Hop",
|
||||
"multihop": "Multi-Hop",
|
||||
"multi hop": "Multi-Hop",
|
||||
"open domain": "Open Domain",
|
||||
"opendomain": "Open Domain",
|
||||
"temporal": "Temporal",
|
||||
}
|
||||
|
||||
def get_category_label(item: Dict[str, Any]) -> str:
|
||||
# 1) 直接用字符串 cat
|
||||
cat = item.get("cat")
|
||||
if isinstance(cat, str) and cat.strip():
|
||||
name = cat.strip()
|
||||
lower = name.lower()
|
||||
return _TYPE_ALIASES.get(lower, name)
|
||||
# 2) 数字 category 转名称
|
||||
cat_num = item.get("category")
|
||||
if isinstance(cat_num, int):
|
||||
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
|
||||
# 3) 备用 type 字段
|
||||
t = item.get("type")
|
||||
if isinstance(t, str) and t.strip():
|
||||
lower = t.strip().lower()
|
||||
return _TYPE_ALIASES.get(lower, t.strip())
|
||||
return "unknown"
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
|
||||
"""基于问题关键词智能选择上下文"""
|
||||
if not contexts:
|
||||
return ""
|
||||
|
||||
# 提取问题关键词(只保留有意义的词)
|
||||
question_lower = question.lower()
|
||||
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
|
||||
question_words = set(re.findall(r'\b\w+\b', question_lower))
|
||||
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
|
||||
|
||||
print(f"🔍 问题关键词: {question_words}")
|
||||
|
||||
# 给每个上下文打分
|
||||
scored_contexts = []
|
||||
for i, context in enumerate(contexts):
|
||||
context_lower = context.lower()
|
||||
score = 0
|
||||
|
||||
# 关键词匹配得分
|
||||
keyword_matches = 0
|
||||
for word in question_words:
|
||||
if word in context_lower:
|
||||
keyword_matches += 1
|
||||
# 关键词出现次数越多,得分越高
|
||||
score += context_lower.count(word) * 2
|
||||
|
||||
# 上下文长度得分(适中的长度更好)
|
||||
context_len = len(context)
|
||||
if 100 < context_len < 2000: # 理想长度范围
|
||||
score += 5
|
||||
elif context_len >= 2000: # 太长可能包含无关信息
|
||||
score += 2
|
||||
|
||||
# 如果是前几个上下文,给予额外分数(通常相关性更高)
|
||||
if i < 3:
|
||||
score += 3
|
||||
|
||||
scored_contexts.append((score, context, keyword_matches))
|
||||
|
||||
# 按得分排序
|
||||
scored_contexts.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择高得分的上下文,直到达到字符限制
|
||||
selected = []
|
||||
total_chars = 0
|
||||
selected_count = 0
|
||||
|
||||
print("📊 上下文相关性分析:")
|
||||
for score, context, matches in scored_contexts[:5]: # 只显示前5个
|
||||
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
|
||||
|
||||
for score, context, matches in scored_contexts:
|
||||
if total_chars + len(context) <= max_chars:
|
||||
selected.append(context)
|
||||
total_chars += len(context)
|
||||
selected_count += 1
|
||||
else:
|
||||
# 如果这个上下文得分很高但放不下,尝试截取
|
||||
if score > 10 and total_chars < max_chars - 500:
|
||||
remaining = max_chars - total_chars
|
||||
# 找到包含关键词的部分
|
||||
lines = context.split('\n')
|
||||
relevant_lines = []
|
||||
current_chars = 0
|
||||
|
||||
for line in lines:
|
||||
line_lower = line.lower()
|
||||
line_relevance = any(word in line_lower for word in question_words)
|
||||
|
||||
if line_relevance and current_chars < remaining - 100:
|
||||
relevant_lines.append(line)
|
||||
current_chars += len(line)
|
||||
|
||||
if relevant_lines:
|
||||
truncated = '\n'.join(relevant_lines)
|
||||
if len(truncated) > 100: # 确保有足够内容
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total_chars += len(truncated)
|
||||
selected_count += 1
|
||||
break # 不再尝试添加更多上下文
|
||||
|
||||
result = "\n\n".join(selected)
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
|
||||
return result
|
||||
|
||||
|
||||
def get_search_params_by_category(category: str):
|
||||
"""根据问题类别调整检索参数"""
|
||||
params_map = {
|
||||
"Multi-Hop": {"limit": 20, "max_chars": 15000},
|
||||
"Temporal": {"limit": 16, "max_chars": 10000},
|
||||
"Open Domain": {"limit": 24, "max_chars": 18000},
|
||||
"Single-Hop": {"limit": 12, "max_chars": 8000},
|
||||
}
|
||||
return params_map.get(category, {"limit": 16, "max_chars": 12000})
|
||||
|
||||
|
||||
async def run_locomo_eval(
|
||||
sample_size: int = 1,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000, # 保持默认值不变
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 32,
|
||||
search_type: str = "hybrid", # 保持默认值不变
|
||||
output_path: str | None = None,
|
||||
skip_ingest_if_exists: bool = True,
|
||||
llm_timeout: float = 10.0,
|
||||
llm_max_retries: int = 1
|
||||
) -> Dict[str, Any]:
|
||||
|
||||
# 函数内部使用三路检索逻辑,但保持参数签名不变
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
|
||||
qa_items: List[Dict[str, Any]] = []
|
||||
if isinstance(raw, list):
|
||||
for entry in raw:
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
items: List[Dict[str, Any]] = qa_items[:sample_size]
|
||||
|
||||
# === 保持原来的数据摄入逻辑 ===
|
||||
entries = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
# 只摄入前1条对话(保持原样)
|
||||
max_dialogues_to_ingest = 1
|
||||
contents: List[str] = []
|
||||
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")
|
||||
|
||||
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
conv = entry.get("conversation", {})
|
||||
sample_id = entry.get("sample_id", f"unknown_{i}")
|
||||
|
||||
print(f"🔍 处理对话 {i+1}: {sample_id}")
|
||||
|
||||
lines: List[str] = []
|
||||
if isinstance(conv, dict):
|
||||
# 收集所有 session_* 的消息
|
||||
session_count = 0
|
||||
for key, val in conv.items():
|
||||
if isinstance(val, list) and key.startswith("session_"):
|
||||
session_count += 1
|
||||
for msg in val:
|
||||
role = msg.get("speaker") or "用户"
|
||||
text = msg.get("text") or ""
|
||||
text = str(text).strip()
|
||||
if not text:
|
||||
continue
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
|
||||
|
||||
if not lines:
|
||||
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
|
||||
continue
|
||||
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
|
||||
|
||||
# 选择要评测的QA对(从所有对话中选取)
|
||||
indexed_items: List[tuple[int, Dict[str, Any]]] = []
|
||||
if isinstance(raw, list):
|
||||
for e_idx, entry in enumerate(raw):
|
||||
for qa in entry.get("qa", []):
|
||||
indexed_items.append((e_idx, qa))
|
||||
else:
|
||||
for qa in raw.get("qa", []):
|
||||
indexed_items.append((0, qa))
|
||||
|
||||
# 这里使用sample_size来限制评测的QA数量
|
||||
selected = indexed_items[:sample_size]
|
||||
items: List[Dict[str, Any]] = [qa for _, qa in selected]
|
||||
|
||||
print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")
|
||||
# === 修改结束 ===
|
||||
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 关键修复:强制重新摄入纯净的对话数据
|
||||
print("🔄 强制重新摄入纯净的对话数据...")
|
||||
await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 初始化embedder用于直接调用
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# connector initialized above
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 上下文诊断收集
|
||||
per_query_context_counts: List[int] = []
|
||||
per_query_context_avg_tokens: List[float] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_tokens_total: List[int] = []
|
||||
# 详细样本调试信息
|
||||
samples: List[Dict[str, Any]] = []
|
||||
# 通用指标
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
# 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1)
|
||||
loc_f1s: List[float] = []
|
||||
# Per-category aggregation
|
||||
cat_counts: Dict[str, int] = {}
|
||||
cat_f1s: Dict[str, List[float]] = {}
|
||||
cat_b1s: Dict[str, List[float]] = {}
|
||||
cat_jss: Dict[str, List[float]] = {}
|
||||
cat_loc_f1s: Dict[str, List[float]] = {}
|
||||
try:
|
||||
for item in items:
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
# 确保答案是字符串
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
cat = get_category_label(item)
|
||||
|
||||
print(f"\n=== 处理问题: {q} ===")
|
||||
|
||||
# 根据类别调整检索参数
|
||||
search_params = get_search_params_by_category(cat)
|
||||
adjusted_limit = search_params["limit"]
|
||||
max_chars = search_params["max_chars"]
|
||||
|
||||
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
|
||||
|
||||
# 改进的检索逻辑:使用三路检索(statements, dialogues, entities)
|
||||
t0 = time.time()
|
||||
contexts_all: List[str] = []
|
||||
search_results = None # 保存完整的检索结果
|
||||
|
||||
try:
|
||||
if search_type == "embedding":
|
||||
# 直接调用嵌入检索,包含三路数据
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
# 构建上下文:优先使用 chunks、statements 和 summaries
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体,避免噪声
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
elif search_type == "keyword":
|
||||
# 直接调用关键词检索
|
||||
search_results = await search_graph(
|
||||
connector=connector,
|
||||
q=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit
|
||||
)
|
||||
dialogs = search_results.get("dialogues", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
|
||||
|
||||
# 构建上下文
|
||||
for d in dialogs:
|
||||
content = str(d.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
# 实体处理(关键词检索的实体可能没有分数)
|
||||
if entities:
|
||||
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
|
||||
if entity_names:
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# 🎯 关键修复:混合检索使用更严格的回退机制
|
||||
print("🔀 使用混合检索(带回退机制)...")
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
)
|
||||
|
||||
# 🎯 关键修复:正确处理混合检索的扁平结构
|
||||
# 新的API返回扁平结构,直接从顶层获取结果
|
||||
if search_results and isinstance(search_results, dict):
|
||||
# 新API返回扁平结构:直接从顶层获取
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# 检查是否有有效结果
|
||||
if chunks or statements or entities or summaries:
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
|
||||
else:
|
||||
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
|
||||
reranked = search_results.get("reranked_results", {})
|
||||
if reranked and isinstance(reranked, dict):
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
else:
|
||||
raise ValueError("混合检索返回空结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
|
||||
search_results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=q,
|
||||
group_id=group_id,
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
)
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
|
||||
|
||||
# 🎯 统一处理:构建上下文(所有检索类型共用)
|
||||
for c in chunks:
|
||||
content = str(c.get("content", "")).strip()
|
||||
if content:
|
||||
contexts_all.append(content)
|
||||
|
||||
for s in statements:
|
||||
stmt_text = str(s.get("statement", "")).strip()
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要:最多加入前3个高分实体
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
# 关键修复:过滤掉包含当前问题答案的上下文
|
||||
filtered_contexts = []
|
||||
for context in contexts_all:
|
||||
content = str(context)
|
||||
# 排除包含当前问题标准答案的上下文
|
||||
if ref_str and ref_str.strip() and ref_str.strip() in content:
|
||||
print("🚫 过滤掉包含标准答案的上下文")
|
||||
continue
|
||||
filtered_contexts.append(context)
|
||||
|
||||
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
|
||||
contexts_all = filtered_contexts
|
||||
|
||||
# 输出完整的检索结果信息
|
||||
print("🔍 检索结果详情:")
|
||||
if search_results:
|
||||
output_data = {
|
||||
"statements": [
|
||||
{
|
||||
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
|
||||
"score": s.get("score", 0.0)
|
||||
}
|
||||
for s in (statements[:2] if 'statements' in locals() else [])
|
||||
],
|
||||
"dialogues": [
|
||||
{
|
||||
"uuid": d.get("uuid", ""),
|
||||
"group_id": d.get("group_id", ""),
|
||||
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
|
||||
"score": d.get("score", 0.0)
|
||||
}
|
||||
for d in (dialogs[:2] if 'dialogs' in locals() else [])
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"name": e.get("name", ""),
|
||||
"entity_type": e.get("entity_type", ""),
|
||||
"score": e.get("score", 0.0)
|
||||
}
|
||||
for e in (entities[:2] if 'entities' in locals() else [])
|
||||
]
|
||||
}
|
||||
print(json.dumps(output_data, ensure_ascii=False, indent=2))
|
||||
else:
|
||||
print(" 无检索结果")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ {search_type}检索失败: {e}")
|
||||
contexts_all = []
|
||||
search_results = None
|
||||
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 使用智能上下文选择
|
||||
context_text = ""
|
||||
if contexts_all:
|
||||
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
|
||||
|
||||
# 如果智能选择后仍然过长,进行最终保护性截断
|
||||
if len(context_text) > max_chars:
|
||||
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
|
||||
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
|
||||
|
||||
# 时间解析
|
||||
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
|
||||
context_text = _resolve_relative_times(context_text, anchor_date)
|
||||
|
||||
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
|
||||
|
||||
print(f"📝 最终上下文长度: {len(context_text)} 字符")
|
||||
|
||||
# 显示不同上下文的预览
|
||||
print("🔍 上下文预览:")
|
||||
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
|
||||
preview = context[:150].replace('\n', ' ')
|
||||
print(f" 上下文{j+1}: {preview}...")
|
||||
|
||||
else:
|
||||
print("❌ 没有检索到有效上下文")
|
||||
context_text = "No relevant context found."
|
||||
|
||||
# 记录上下文诊断信息
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
|
||||
|
||||
# LLM 提示词
|
||||
messages = [
|
||||
{"role": "system", "content": (
|
||||
"You are a precise QA assistant. Answer following these rules:\n"
|
||||
"1) Extract the EXACT information mentioned in the context\n"
|
||||
"2) For time questions: calculate actual dates from relative times\n"
|
||||
"3) Return ONLY the answer text in simplest form\n"
|
||||
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
|
||||
"5) If no clear answer found, respond with 'Unknown'"
|
||||
)},
|
||||
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
# 使用异步调用
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
|
||||
# 兼容不同的响应格式
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
|
||||
|
||||
# 计算指标(确保使用字符串)
|
||||
f1_val = common_f1(str(pred), ref_str)
|
||||
b1_val = bleu1(str(pred), ref_str)
|
||||
j_val = jaccard(str(pred), ref_str)
|
||||
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
# Accumulate by category
|
||||
cat_counts[cat] = cat_counts.get(cat, 0) + 1
|
||||
cat_f1s.setdefault(cat, []).append(f1_val)
|
||||
cat_b1s.setdefault(cat, []).append(b1_val)
|
||||
cat_jss.setdefault(cat, []).append(j_val)
|
||||
|
||||
# LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1
|
||||
if item.get("category") in [2, 3, 4]:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
elif item.get("category") in [1]:
|
||||
loc_val = loc_multi_f1(str(pred), ref_str)
|
||||
else:
|
||||
loc_val = loc_f1_score(str(pred), ref_str)
|
||||
loc_f1s.append(loc_val)
|
||||
cat_loc_f1s.setdefault(cat, []).append(loc_val)
|
||||
|
||||
# 保存完整的检索结果信息
|
||||
samples.append({
|
||||
"question": q,
|
||||
"answer": ref_str,
|
||||
"category": cat,
|
||||
"prediction": pred,
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val,
|
||||
"loc_f1": loc_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": adjusted_limit,
|
||||
"max_chars": max_chars
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": (t1 - t0) * 1000,
|
||||
"llm_ms": (t3 - t2) * 1000
|
||||
}
|
||||
})
|
||||
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {ref_str}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
|
||||
|
||||
# Compute per-category averages and dispersion (std, iqr)
|
||||
def _percentile(sorted_vals: List[float], p: float) -> float:
|
||||
if not sorted_vals:
|
||||
return 0.0
|
||||
if len(sorted_vals) == 1:
|
||||
return sorted_vals[0]
|
||||
k = (len(sorted_vals) - 1) * p
|
||||
f = int(k)
|
||||
c = f + 1 if f + 1 < len(sorted_vals) else f
|
||||
if f == c:
|
||||
return sorted_vals[f]
|
||||
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
|
||||
|
||||
by_category: Dict[str, Dict[str, float | int]] = {}
|
||||
for c in cat_counts:
|
||||
f_list = cat_f1s.get(c, [])
|
||||
b_list = cat_b1s.get(c, [])
|
||||
j_list = cat_jss.get(c, [])
|
||||
lf_list = cat_loc_f1s.get(c, [])
|
||||
j_sorted = sorted(j_list)
|
||||
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
|
||||
j_q75 = _percentile(j_sorted, 0.75)
|
||||
j_q25 = _percentile(j_sorted, 0.25)
|
||||
by_category[c] = {
|
||||
"count": cat_counts[c],
|
||||
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
|
||||
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
|
||||
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
|
||||
"j_std": j_std,
|
||||
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
|
||||
# 参考 LoCoMo 评测的类别专用 F1
|
||||
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
|
||||
}
|
||||
|
||||
# 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿
|
||||
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
|
||||
|
||||
result = {
|
||||
"dataset": "locomo",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": sum(f1s) / max(len(f1s), 1),
|
||||
"b1": sum(b1s) / max(len(b1s), 1),
|
||||
"j": sum(jss) / max(len(jss), 1),
|
||||
# LoCoMo 类别专用 F1 的总体
|
||||
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
|
||||
},
|
||||
"by_category": by_category,
|
||||
"category_counts": cat_counts,
|
||||
"cum_accuracy_by_category": cum_accuracy_by_category,
|
||||
"context": {
|
||||
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"skip_ingest_if_exists": skip_ingest_if_exists,
|
||||
"llm_timeout": llm_timeout,
|
||||
"llm_max_retries": llm_max_retries,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens
|
||||
},
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
if output_path:
|
||||
try:
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"✅ 结果已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"❌ 保存结果失败: {e}")
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
|
||||
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
|
||||
parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval")
|
||||
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
|
||||
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
|
||||
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
|
||||
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
|
||||
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
|
||||
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
|
||||
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
|
||||
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
|
||||
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
|
||||
args = parser.parse_args()
|
||||
|
||||
load_dotenv()
|
||||
|
||||
result = asyncio.run(run_locomo_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
output_path=args.output_path,
|
||||
skip_ingest_if_exists=args.skip_ingest_if_exists,
|
||||
llm_timeout=args.llm_timeout,
|
||||
llm_max_retries=args.llm_max_retries
|
||||
))
|
||||
|
||||
print("\n" + "="*50)
|
||||
print("📊 最终评测结果:")
|
||||
print(f" 样本数量: {result['items']}")
|
||||
print(f" F1: {result['metrics']['f1']:.3f}")
|
||||
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
|
||||
print(f" Jaccard: {result['metrics']['j']:.3f}")
|
||||
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
|
||||
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
|
||||
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
|
||||
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
|
||||
|
||||
if result['by_category']:
|
||||
print("\n📈 按类别细分:")
|
||||
for cat, metrics in result['by_category'].items():
|
||||
print(f" {cat}:")
|
||||
print(f" 样本数: {metrics['count']}")
|
||||
print(f" F1: {metrics['f1']:.3f}")
|
||||
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
|
||||
print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,324 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
|
||||
if not contexts:
|
||||
return ""
|
||||
import re
|
||||
# 提取问题关键词(移除停用词)
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
# 评分
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
# 选择直到达到字符限制,必要时截断包含关键词的段落
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
|
||||
"""Compose a text context from `dialog` list in msc_self_instruct item."""
|
||||
parts: List[str] = []
|
||||
for turn in dialog_obj.get("dialog", []):
|
||||
speaker = turn.get("speaker", "")
|
||||
text = turn.get("text", "")
|
||||
if text:
|
||||
parts.append(f"{speaker}: {text}")
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""Combine dialogues from embedding and keyword searches (embedding first)."""
|
||||
if results is None:
|
||||
return []
|
||||
emb = []
|
||||
kw = []
|
||||
if isinstance(results.get("embedding_search"), dict):
|
||||
emb = results.get("embedding_search", {}).get("dialogues", []) or []
|
||||
elif isinstance(results.get("dialogues"), list):
|
||||
emb = results.get("dialogues", []) or []
|
||||
if isinstance(results.get("keyword_search"), dict):
|
||||
kw = results.get("keyword_search", {}).get("dialogues", []) or []
|
||||
seen = set()
|
||||
merged: List[Dict[str, Any]] = []
|
||||
for d in emb:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
for d in kw:
|
||||
k = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if k not in seen:
|
||||
merged.append(d)
|
||||
seen.add(k)
|
||||
return merged
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
|
||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||
await ingest_contexts_via_full_pipeline(contexts, group_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
contexts_used: List[str] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
try:
|
||||
for item in items:
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
# 检索:对齐 locomo 的三路检索(dialogues/statements/entities)
|
||||
t0 = time.time()
|
||||
try:
|
||||
results = await run_hybrid_search(
|
||||
query_text=question,
|
||||
search_type=search_type,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
latencies_search.append((t1 - t0) * 1000)
|
||||
|
||||
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
|
||||
contexts_all: List[str] = []
|
||||
if results:
|
||||
if search_type == "hybrid":
|
||||
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
|
||||
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
|
||||
emb_dialogs = emb.get("dialogues", [])
|
||||
emb_statements = emb.get("statements", [])
|
||||
emb_entities = emb.get("entities", [])
|
||||
kw_dialogs = kw.get("dialogues", [])
|
||||
kw_statements = kw.get("statements", [])
|
||||
kw_entities = kw.get("entities", [])
|
||||
all_dialogs = emb_dialogs + kw_dialogs
|
||||
all_statements = emb_statements + kw_statements
|
||||
all_entities = emb_entities + kw_entities
|
||||
|
||||
# 简单去重与限制
|
||||
seen_texts = set()
|
||||
for d in all_dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
for s in all_statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text and text not in seen_texts:
|
||||
contexts_all.append(text)
|
||||
seen_texts.add(text)
|
||||
if len(contexts_all) >= search_limit:
|
||||
break
|
||||
# 实体摘要(最多3个)
|
||||
names = []
|
||||
merged_entities = all_entities[:]
|
||||
for e in merged_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
if name and name not in names:
|
||||
names.append(name)
|
||||
if len(names) >= 3:
|
||||
break
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
else:
|
||||
dialogs = results.get("dialogues", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
for d in dialogs:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
|
||||
if names:
|
||||
contexts_all.append("EntitySummary: " + ", ".join(names))
|
||||
|
||||
# 智能选择并截断到预算
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text[:200])
|
||||
|
||||
# Call LLM (使用异步调用)
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
t2 = time.time()
|
||||
resp = await llm_client.chat(messages=messages)
|
||||
t3 = time.time()
|
||||
latencies_llm.append((t3 - t2) * 1000)
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
|
||||
# Aggregate metrics
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"accuracy": acc,
|
||||
# Placeholders for extensibility
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"avg_context_tokens": ctx_avg_tokens,
|
||||
}
|
||||
return result
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
|
||||
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(
|
||||
run_memsciqa_eval(
|
||||
sample_size=args.sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
)
|
||||
)
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,576 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
|
||||
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR))
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
tp = len(set(ps) & set(rs))
|
||||
if tp == 0:
|
||||
return 0.0
|
||||
precision = tp / len(ps)
|
||||
recall = tp / len(rs)
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
overlap = len([w for w in ps if w in rs])
|
||||
return overlap / max(len(ps), 1)
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
ps = set(pred.lower().split())
|
||||
rs = set(ref.lower().split())
|
||||
union = len(ps | rs)
|
||||
if union == 0:
|
||||
return 0.0
|
||||
return len(ps & rs) / union
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
|
||||
|
||||
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
|
||||
"""
|
||||
if not contexts:
|
||||
return ""
|
||||
question_lower = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but'
|
||||
}
|
||||
question_words = set(re.findall(r"\b\w+\b", question_lower))
|
||||
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
|
||||
|
||||
scored = []
|
||||
for i, ctx in enumerate(contexts):
|
||||
ctx_lower = (ctx or "").lower()
|
||||
score = 0
|
||||
matches = 0
|
||||
for w in question_words:
|
||||
if w in ctx_lower:
|
||||
matches += 1
|
||||
score += ctx_lower.count(w) * 2
|
||||
length = len(ctx)
|
||||
if 100 < length < 2000:
|
||||
score += 5
|
||||
elif length >= 2000:
|
||||
score += 2
|
||||
if i < 3:
|
||||
score += 3
|
||||
scored.append((score, ctx, matches))
|
||||
|
||||
scored.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
selected: List[str] = []
|
||||
total = 0
|
||||
for score, ctx, _ in scored:
|
||||
if total + len(ctx) <= max_chars:
|
||||
selected.append(ctx)
|
||||
total += len(ctx)
|
||||
else:
|
||||
if score > 10 and total < max_chars - 200:
|
||||
remaining = max_chars - total
|
||||
lines = ctx.split('\n')
|
||||
rel_lines: List[str] = []
|
||||
cur = 0
|
||||
for line in lines:
|
||||
l = line.lower()
|
||||
if any(w in l for w in question_words) and cur < remaining - 50:
|
||||
rel_lines.append(line)
|
||||
cur += len(line)
|
||||
if rel_lines:
|
||||
truncated = '\n'.join(rel_lines)
|
||||
if len(truncated) > 50:
|
||||
selected.append(truncated + "\n[相关内容截断...]")
|
||||
total += len(truncated)
|
||||
break
|
||||
return "\n\n".join(selected)
|
||||
|
||||
|
||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
|
||||
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3)。"""
|
||||
ql = (question or "").lower()
|
||||
stop_words = {
|
||||
'what','when','where','who','why','how','did','do','does','is','are','was','were',
|
||||
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
|
||||
}
|
||||
words = re.findall(r"\b[\w-]+\b", ql)
|
||||
kws = [w for w in words if w not in stop_words and len(w) >= 3]
|
||||
# 去重保序
|
||||
seen = set()
|
||||
uniq = []
|
||||
for w in kws:
|
||||
if w not in seen:
|
||||
uniq.append(w)
|
||||
seen.add(w)
|
||||
if len(uniq) >= max_keywords:
|
||||
break
|
||||
return uniq
|
||||
|
||||
|
||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
||||
|
||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
||||
"""
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
tl = (ctx or "").lower()
|
||||
match_count = sum(1 for k in keywords if k in tl)
|
||||
length = len(ctx)
|
||||
score = match_count * 200 + min(length, 100000) / 100.0
|
||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
||||
return results[:max(top_n, 0)]
|
||||
|
||||
|
||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
||||
|
||||
|
||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(f"未找到数据集: {data_path}")
|
||||
items: List[Dict[str, Any]] = []
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
items.append(json.loads(line))
|
||||
except Exception:
|
||||
# 跳过坏行但不中断
|
||||
continue
|
||||
return items
|
||||
|
||||
|
||||
async def run_memsciqa_test(
|
||||
sample_size: int = 3,
|
||||
group_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
llm_max_tokens: int = 64,
|
||||
search_type: str = "embedding",
|
||||
data_path: str | None = None,
|
||||
start_index: int = 0,
|
||||
verbose: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
|
||||
|
||||
- 支持从指定索引开始与评估全部样本(sample_size<=0)
|
||||
- 支持在摄入前重置组(清空图)与跳过摄入
|
||||
- 支持 keyword / embedding / hybrid 三种检索
|
||||
"""
|
||||
|
||||
# 默认使用指定的 memsci 组 ID
|
||||
group_id = group_id or "group_memsci"
|
||||
|
||||
# 数据路径解析(项目根与当前工作目录兜底)
|
||||
if not data_path:
|
||||
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
if os.path.exists(proj_path):
|
||||
data_path = proj_path
|
||||
elif os.path.exists(cwd_path):
|
||||
data_path = cwd_path
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
|
||||
# 加载数据
|
||||
all_items = load_dataset_memsciqa(data_path)
|
||||
if sample_size is None or sample_size <= 0:
|
||||
items = all_items[start_index:]
|
||||
else:
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 评估循环
|
||||
latencies_llm: List[float] = []
|
||||
latencies_search: List[float] = []
|
||||
# 存储完整上下文文本用于统计
|
||||
contexts_used: List[str] = []
|
||||
per_query_context_chars: List[int] = []
|
||||
per_query_context_counts: List[int] = []
|
||||
correct_flags: List[float] = []
|
||||
f1s: List[float] = []
|
||||
b1s: List[float] = []
|
||||
jss: List[float] = []
|
||||
samples: List[Dict[str, Any]] = []
|
||||
|
||||
total_items = len(items)
|
||||
for idx, item in enumerate(items):
|
||||
if verbose:
|
||||
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
|
||||
# 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py)
|
||||
t0 = time.time()
|
||||
results = None
|
||||
try:
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
# 使用嵌入检索(与 qwen_search_eval 对齐)
|
||||
results = await search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
elif search_type == "keyword":
|
||||
# 关键词检索(直接调用 graph_search)
|
||||
results = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
limit=search_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
t1 = time.time()
|
||||
search_ms = (t1 - t0) * 1000
|
||||
latencies_search.append(search_ms)
|
||||
|
||||
# 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py)
|
||||
contexts_all: List[str] = []
|
||||
retrieved_counts: Dict[str, int] = {}
|
||||
if results:
|
||||
chunks = results.get("chunks", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
summaries = results.get("summaries", [])
|
||||
retrieved_counts = {
|
||||
"chunks": len(chunks),
|
||||
"statements": len(statements),
|
||||
"entities": len(entities),
|
||||
"summaries": len(summaries),
|
||||
}
|
||||
# 优先使用 chunks
|
||||
for c in chunks:
|
||||
text = str(c.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 statements
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 summaries
|
||||
for sm in summaries:
|
||||
text = str(sm.get("summary", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
if verbose:
|
||||
if retrieved_counts:
|
||||
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
q_keywords = extract_question_keywords(question, max_keywords=8)
|
||||
if q_keywords:
|
||||
print(f"🔍 问题关键词: {set(q_keywords)}")
|
||||
if contexts_all:
|
||||
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
|
||||
if analysis:
|
||||
print("📊 上下文相关性分析:")
|
||||
for a in analysis:
|
||||
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
|
||||
# 打印检索到的上下文预览,便于定位为何为 Unknown
|
||||
print("🔎 上下文预览(最多前10条,每条截断展示):")
|
||||
for i, ctx in enumerate(contexts_all[:10]):
|
||||
preview = str(ctx).replace("\n", " ")
|
||||
if len(preview) > 300:
|
||||
preview = preview[:300] + "..."
|
||||
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
|
||||
# 标注参考答案是否出现在任一上下文中
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower:
|
||||
hits = []
|
||||
for i, ctx in enumerate(contexts_all):
|
||||
if ref_lower in str(ctx).lower():
|
||||
hits.append(i+1)
|
||||
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
|
||||
|
||||
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
|
||||
if not context_text:
|
||||
context_text = "No relevant context found."
|
||||
contexts_used.append(context_text)
|
||||
per_query_context_chars.append(len(context_text))
|
||||
per_query_context_counts.append(len(contexts_all))
|
||||
|
||||
if verbose:
|
||||
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
|
||||
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
|
||||
# 展示拼接后的上下文片段,便于核查是否包含答案
|
||||
concat_preview = context_text.replace("\n", " ")
|
||||
if len(concat_preview) > 600:
|
||||
concat_preview = concat_preview[:600] + "..."
|
||||
print(f"🧵 拼接上下文预览: {concat_preview}")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
|
||||
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
|
||||
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
|
||||
"3) Keep your answer brief and to the point;\n"
|
||||
"4) Do not add explanations or additional text beyond the answer."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
|
||||
]
|
||||
|
||||
t2 = time.time()
|
||||
try:
|
||||
# 使用异步调用
|
||||
resp = await llm.chat(messages=messages)
|
||||
# 更健壮的响应解析,处理不同的LLM响应格式
|
||||
if hasattr(resp, 'content'):
|
||||
pred = resp.content.strip()
|
||||
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
|
||||
pred = resp["choices"][0]["message"]["content"].strip()
|
||||
elif isinstance(resp, dict) and "content" in resp:
|
||||
pred = resp["content"].strip()
|
||||
elif isinstance(resp, str):
|
||||
pred = resp.strip()
|
||||
else:
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
|
||||
|
||||
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
|
||||
if pred.lower() in ["unknown", ""]:
|
||||
# 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题
|
||||
ref_lower = (str(reference) or "").lower()
|
||||
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
|
||||
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
|
||||
except Exception as e:
|
||||
# 更详细的错误处理
|
||||
pred = "Unknown"
|
||||
print(f"⚠️ LLM调用异常: {e}")
|
||||
t3 = time.time()
|
||||
llm_ms = (t3 - t2) * 1000
|
||||
latencies_llm.append(llm_ms)
|
||||
|
||||
exact = exact_match(pred, reference)
|
||||
correct_flags.append(exact)
|
||||
f1_val = f1_score(str(pred), str(reference))
|
||||
b1_val = bleu1(str(pred), str(reference))
|
||||
j_val = jaccard(str(pred), str(reference))
|
||||
f1s.append(f1_val)
|
||||
b1s.append(b1_val)
|
||||
jss.append(j_val)
|
||||
|
||||
if verbose:
|
||||
print(f"🤖 LLM 回答: {pred}")
|
||||
print(f"✅ 正确答案: {reference}")
|
||||
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
|
||||
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
|
||||
|
||||
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
|
||||
samples.append({
|
||||
"question": str(question),
|
||||
"answer": str(reference),
|
||||
"prediction": str(pred),
|
||||
"metrics": {
|
||||
"f1": f1_val,
|
||||
"b1": b1_val,
|
||||
"j": j_val
|
||||
},
|
||||
"retrieval": {
|
||||
"retrieved_documents": len(contexts_all),
|
||||
"context_length": len(context_text),
|
||||
"search_limit": search_limit,
|
||||
"max_chars": context_char_budget
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": search_ms,
|
||||
"llm_ms": llm_ms
|
||||
}
|
||||
})
|
||||
|
||||
# 计算总体指标与聚合
|
||||
acc = sum(correct_flags) / max(len(correct_flags), 1)
|
||||
ctx_avg_tokens = avg_context_tokens(contexts_used)
|
||||
result = {
|
||||
"dataset": "memsciqa",
|
||||
"items": len(items),
|
||||
"metrics": {
|
||||
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
|
||||
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
|
||||
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
|
||||
},
|
||||
"context": {
|
||||
"avg_tokens": ctx_avg_tokens,
|
||||
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
|
||||
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
|
||||
"avg_memory_tokens": 0.0
|
||||
},
|
||||
"latency": {
|
||||
"search": latency_stats(latencies_search),
|
||||
"llm": latency_stats(latencies_llm),
|
||||
},
|
||||
"samples": samples,
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"llm_temperature": llm_temperature,
|
||||
"llm_max_tokens": llm_max_tokens,
|
||||
"search_type": search_type,
|
||||
"start_index": start_index,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
|
||||
},
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
try:
|
||||
await connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
|
||||
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
||||
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
|
||||
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
|
||||
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)")
|
||||
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
|
||||
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
|
||||
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
|
||||
args = parser.parse_args()
|
||||
|
||||
sample_size = 0 if args.all else args.sample_size
|
||||
|
||||
verbose_flag = False if args.quiet else args.verbose
|
||||
result = asyncio.run(
|
||||
run_memsciqa_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
data_path=args.data_path,
|
||||
start_index=args.start_index,
|
||||
verbose=verbose_flag,
|
||||
)
|
||||
)
|
||||
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果保存
|
||||
out_path = args.output
|
||||
if not out_path:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, "results")
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
|
||||
try:
|
||||
os.makedirs(os.path.dirname(out_path), exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n💾 结果已保存: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 结果保存失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,150 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
||||
|
||||
|
||||
async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
group_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
llm_temperature: float | None = None,
|
||||
llm_max_tokens: int | None = None,
|
||||
search_type: str | None = None,
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
group_id = group_id or SELECTED_GROUP_ID
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(group_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
if start_index is not None:
|
||||
kwargs["start_index"] = start_index
|
||||
if max_contexts_per_item is not None:
|
||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
||||
return await run_longmemeval_test(**kwargs)
|
||||
raise ValueError(f"未知数据集: {dataset}")
|
||||
|
||||
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
|
||||
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
|
||||
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json")
|
||||
parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
|
||||
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
|
||||
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
|
||||
# 仅透传到 longmemeval;其他数据集忽略
|
||||
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
|
||||
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
|
||||
parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(run(
|
||||
args.dataset,
|
||||
args.sample_size,
|
||||
args.reset_group,
|
||||
args.group_id,
|
||||
args.judge_model,
|
||||
args.search_limit,
|
||||
args.context_char_budget,
|
||||
args.llm_temperature,
|
||||
args.llm_max_tokens,
|
||||
args.search_type,
|
||||
args.start_index,
|
||||
args.max_contexts_per_item,
|
||||
))
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# 结果输出逻辑保持不变
|
||||
if args.output:
|
||||
out_path = args.output
|
||||
else:
|
||||
eval_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
|
||||
out_filename = f"{args.dataset}_{args.sample_size}.json"
|
||||
out_path = os.path.join(dataset_results_dir, out_filename)
|
||||
|
||||
out_dir = os.path.dirname(out_path)
|
||||
if out_dir and not os.path.exists(out_dir):
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n结果已保存到: {out_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -187,11 +187,11 @@ class ChunkerClient:
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
@@ -201,9 +201,9 @@ class ChunkerClient:
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
@@ -212,13 +212,13 @@ class ChunkerClient:
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
@@ -228,14 +228,14 @@ class ChunkerClient:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
@@ -260,7 +260,7 @@ class ChunkerClient:
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
@@ -268,7 +268,7 @@ class ChunkerClient:
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
|
||||
@@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel):
|
||||
"""Parameters for temporal search queries in the knowledge graph.
|
||||
|
||||
Attributes:
|
||||
group_id: Group ID to filter search results (default: 'test')
|
||||
end_user_id: Group ID to filter search results (default: 'test')
|
||||
apply_id: Application ID to filter search results
|
||||
user_id: User ID to filter search results
|
||||
start_date: Start date for temporal filtering (format: 'YYYY-MM-DD')
|
||||
@@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel):
|
||||
invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD')
|
||||
limit: Maximum number of results to return (default: 3)
|
||||
"""
|
||||
group_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.")
|
||||
apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.")
|
||||
user_id: Optional[str] = Field(None, description="The user ID to filter the search.")
|
||||
start_date: Optional[str] = Field(None, description="The start date for the search.")
|
||||
|
||||
@@ -103,9 +103,7 @@ class Edge(BaseModel):
|
||||
id: Unique identifier for the edge
|
||||
source: ID of the source node
|
||||
target: ID of the target node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this edge
|
||||
created_at: Timestamp when the edge was created (system perspective)
|
||||
expired_at: Optional timestamp when the edge expires (system perspective)
|
||||
@@ -113,9 +111,7 @@ class Edge(BaseModel):
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.")
|
||||
source: str = Field(..., description="The ID of the source node.")
|
||||
target: str = Field(..., description="The ID of the target node.")
|
||||
group_id: str = Field(..., description="The group ID of the edge.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the edge.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
|
||||
@@ -185,18 +181,14 @@ class Node(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the node
|
||||
name: Name of the node
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
run_id: Unique identifier for the pipeline run that created this node
|
||||
created_at: Timestamp when the node was created (system perspective)
|
||||
expired_at: Optional timestamp when the node expires (system perspective)
|
||||
"""
|
||||
id: str = Field(..., description="The unique identifier for the node.")
|
||||
name: str = Field(..., description="The name of the node.")
|
||||
group_id: str = Field(..., description="The group ID of the node.")
|
||||
user_id: str = Field(..., description="The user ID of the edge.")
|
||||
apply_id: str = Field(..., description="The apply ID of the edge.")
|
||||
end_user_id: str = Field(..., description="The end user ID of the node.")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(..., description="The valid time of the node from system perspective.")
|
||||
expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.")
|
||||
|
||||
@@ -55,7 +55,7 @@ class Statement(BaseModel):
|
||||
Attributes:
|
||||
id: Unique identifier for the statement
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
end_user_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
@@ -73,7 +73,7 @@ class Statement(BaseModel):
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.")
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
@@ -159,9 +159,7 @@ class DialogData(BaseModel):
|
||||
context: Full conversation context
|
||||
dialog_embedding: Optional embedding vector for the entire dialog
|
||||
ref_id: Reference ID linking to external dialog system
|
||||
group_id: Group ID for multi-tenancy
|
||||
user_id: User ID for user-specific data
|
||||
apply_id: Application ID for application-specific data
|
||||
end_user_id: End user ID for multi-tenancy
|
||||
created_at: Timestamp when the dialog was created
|
||||
expired_at: Timestamp when the dialog expires (default: far future)
|
||||
metadata: Additional metadata as key-value pairs
|
||||
@@ -175,9 +173,7 @@ class DialogData(BaseModel):
|
||||
context: ConversationContext = Field(..., description="The full conversation context as a single string.")
|
||||
dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.")
|
||||
ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.")
|
||||
group_id: str = Field(default=..., description="Group ID of dialogue data")
|
||||
user_id: str = Field(..., description="USER ID of dialogue data")
|
||||
apply_id: str = Field(..., description="APPLY ID of dialogue data")
|
||||
end_user_id: str = Field(default=..., description="End user ID of dialogue data")
|
||||
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.")
|
||||
expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.")
|
||||
@@ -250,11 +246,11 @@ class DialogData(BaseModel):
|
||||
return []
|
||||
|
||||
def assign_group_id_to_statements(self) -> None:
|
||||
"""Assign this dialog's group_id to all statements in all chunks.
|
||||
"""Assign this dialog's end_user_id to all statements in all chunks.
|
||||
|
||||
This method updates statements that don't have a group_id set.
|
||||
This method updates statements that don't have a end_user_id set.
|
||||
"""
|
||||
for chunk in self.chunks:
|
||||
for statement in chunk.statements:
|
||||
if statement.group_id is None:
|
||||
statement.group_id = self.group_id
|
||||
if statement.end_user_id is None:
|
||||
statement.end_user_id = self.end_user_id
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
@@ -396,13 +397,13 @@ def rerank_with_activation(
|
||||
return reranked
|
||||
|
||||
|
||||
def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
||||
"""Log search query information using the logger.
|
||||
|
||||
Args:
|
||||
query_text: The search query text
|
||||
search_type: Type of search (keyword, embedding, hybrid)
|
||||
group_id: Group identifier for filtering
|
||||
end_user_id: Group identifier for filtering
|
||||
limit: Maximum number of results
|
||||
include: List of result types to include
|
||||
log_file: Deprecated parameter, kept for backward compatibility
|
||||
@@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li
|
||||
# Log using the standard logger
|
||||
logger.info(
|
||||
f"Search query: query='{cleaned_query}', type={search_type}, "
|
||||
f"group_id={group_id}, limit={limit}, include={include}"
|
||||
f"end_user_id={end_user_id}, limit={limit}, include={include}"
|
||||
)
|
||||
|
||||
|
||||
@@ -672,7 +673,7 @@ def apply_reranker_placeholder(
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: str | None,
|
||||
end_user_id: str | None,
|
||||
limit: int,
|
||||
include: List[str],
|
||||
output_path: str | None,
|
||||
@@ -715,7 +716,7 @@ async def run_hybrid_search(
|
||||
}
|
||||
|
||||
# Log the search query
|
||||
log_search_query(query_text, search_type, group_id, limit, include)
|
||||
log_search_query(query_text, search_type, end_user_id, limit, include)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
results = {}
|
||||
@@ -732,7 +733,7 @@ async def run_hybrid_search(
|
||||
search_graph(
|
||||
connector=connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include
|
||||
)
|
||||
@@ -769,7 +770,7 @@ async def run_hybrid_search(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
@@ -916,9 +917,7 @@ async def run_hybrid_search(
|
||||
|
||||
|
||||
async def search_by_temporal(
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -929,7 +928,7 @@ async def search_by_temporal(
|
||||
Temporal search across Statements.
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id
|
||||
- Optionally filters by end_user_id
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
@@ -939,9 +938,7 @@ async def search_by_temporal(
|
||||
end_date = normalize_date_safe(end_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -950,9 +947,7 @@ async def search_by_temporal(
|
||||
})
|
||||
statements = await search_graph_by_temporal(
|
||||
connector=connector,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -964,9 +959,7 @@ async def search_by_temporal(
|
||||
|
||||
async def search_by_keyword_temporal(
|
||||
query_text: str,
|
||||
group_id: Optional[str] = "test",
|
||||
apply_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = "test",
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
@@ -987,9 +980,7 @@ async def search_by_keyword_temporal(
|
||||
invalid_date = normalize_date_safe(invalid_date)
|
||||
|
||||
params = TemporalSearchParams.model_validate({
|
||||
"group_id": group_id,
|
||||
"apply_id": apply_id,
|
||||
"user_id": user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"valid_date": valid_date,
|
||||
@@ -999,9 +990,7 @@ async def search_by_keyword_temporal(
|
||||
statements = await search_graph_by_keyword_temporal(
|
||||
connector=connector,
|
||||
query_text=query_text,
|
||||
group_id=params.group_id,
|
||||
apply_id=params.apply_id,
|
||||
user_id=params.user_id,
|
||||
end_user_id=params.end_user_id,
|
||||
start_date=params.start_date,
|
||||
end_date=params.end_date,
|
||||
valid_date=params.valid_date,
|
||||
@@ -1013,7 +1002,7 @@ async def search_by_keyword_temporal(
|
||||
|
||||
async def search_chunk_by_chunk_id(
|
||||
chunk_id: str,
|
||||
group_id: Optional[str] = "test",
|
||||
end_user_id: Optional[str] = "test",
|
||||
limit: int = 1,
|
||||
):
|
||||
"""
|
||||
@@ -1023,7 +1012,7 @@ async def search_chunk_by_chunk_id(
|
||||
chunks = await search_graph_by_chunk_id(
|
||||
connector=connector,
|
||||
chunk_id=chunk_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
@@ -555,8 +555,8 @@ class DataPreprocessor:
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
|
||||
|
||||
# 获取group_id,如果不存在则生成默认值
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
# 获取end_user_id,如果不存在则生成默认值
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -574,7 +574,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
@@ -644,7 +644,7 @@ class DataPreprocessor:
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}')))
|
||||
group_id = item.get('group_id', f'group_default_{i}')
|
||||
end_user_id = item.get('end_user_id', f'group_default_{i}')
|
||||
user_id = item.get('user_id', f'user_default_{i}')
|
||||
apply_id = item.get('apply_id', f'apply_default_{i}')
|
||||
|
||||
@@ -657,7 +657,7 @@ class DataPreprocessor:
|
||||
dialog_data = DialogData(
|
||||
context=context,
|
||||
ref_id=dialog_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
metadata=metadata
|
||||
|
||||
@@ -199,7 +199,7 @@ def accurate_match(
|
||||
entity_nodes: List[ExtractedEntityNode]
|
||||
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
|
||||
"""
|
||||
精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
|
||||
返回: (deduped_entities, id_redirect, exact_merge_map)
|
||||
"""
|
||||
exact_merge_map: Dict[str, Dict] = {}
|
||||
@@ -210,8 +210,8 @@ def accurate_match(
|
||||
for ent in entity_nodes:
|
||||
name_norm = (getattr(ent, "name", "") or "").strip()
|
||||
type_norm = (getattr(ent, "entity_type", "") or "").strip()
|
||||
key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 group_id 为范围边界
|
||||
key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}"
|
||||
# 为避免跨业务组误并,明确以 end_user_id 为范围边界
|
||||
if key not in canonical_map:
|
||||
canonical_map[key] = ent
|
||||
id_redirect[ent.id] = ent.id
|
||||
@@ -223,11 +223,11 @@ def accurate_match(
|
||||
id_redirect[ent.id] = canonical.id
|
||||
# 记录精确匹配的合并项(使用规范化键,避免外层变量误用)
|
||||
try:
|
||||
k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
|
||||
if k not in exact_merge_map:
|
||||
exact_merge_map[k] = {
|
||||
"canonical_id": canonical.id,
|
||||
"group_id": canonical.group_id,
|
||||
"end_user_id": canonical.end_user_id,
|
||||
"name": canonical.name,
|
||||
"entity_type": canonical.entity_type,
|
||||
"merged_ids": set(),
|
||||
@@ -596,7 +596,7 @@ def fuzzy_match(
|
||||
b = deduped_entities[j]
|
||||
|
||||
# 跳过不同业务组的实体
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
j += 1
|
||||
continue
|
||||
|
||||
@@ -671,7 +671,7 @@ def fuzzy_match(
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]"
|
||||
fuzzy_merge_records.append(
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | "
|
||||
f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | "
|
||||
f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}"
|
||||
)
|
||||
except Exception:
|
||||
@@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能
|
||||
# 记录 LLM 融合日志
|
||||
try:
|
||||
llm_records.append(
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
# 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason
|
||||
except Exception:
|
||||
@@ -847,7 +847,7 @@ async def LLM_disamb_decision(
|
||||
id_redirect[k] = a.id
|
||||
try:
|
||||
disamb_records.append(
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})"
|
||||
f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -174,7 +174,7 @@ async def _judge_pair(
|
||||
pass
|
||||
# 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
@@ -235,7 +235,7 @@ async def _judge_pair_disamb(
|
||||
except Exception:
|
||||
pass
|
||||
ctx = {
|
||||
"same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None),
|
||||
"same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None),
|
||||
"type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)),
|
||||
"name_text_sim": name_text_sim,
|
||||
"name_embed_sim": name_embed_sim,
|
||||
@@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了
|
||||
a = entity_nodes[i]
|
||||
for j in range(i + 1, len(entity_nodes)):
|
||||
b = entity_nodes[j]
|
||||
# 规则1:必须属于同一组(group_id相同,不同组的实体不重复)
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
# 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复)
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
# 规则2:类型必须兼容(调用_simple_type_ok判断)
|
||||
if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)):
|
||||
@@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
- max_rounds: upper bound for iterative passes (default 3)
|
||||
- auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90)
|
||||
- co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83)
|
||||
- shuffle_each_round: whether to shuffle entities within group_id each round to vary block composition
|
||||
- shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition
|
||||
|
||||
Returns:
|
||||
- global_redirect: dict losing_id -> canonical_id accumulated across rounds
|
||||
@@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
|
||||
def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]:
|
||||
"""
|
||||
按 group_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对
|
||||
|
||||
Args:
|
||||
nodes: 实体节点列表
|
||||
@@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
"""
|
||||
groups: Dict[str, List[ExtractedEntityNode]] = {}
|
||||
for e in nodes:
|
||||
gid = getattr(e, "group_id", None)
|
||||
gid = getattr(e, "end_user_id", None)
|
||||
groups.setdefault(str(gid), []).append(e)
|
||||
blocks: List[List[ExtractedEntityNode]] = []
|
||||
for gid, arr in groups.items():
|
||||
@@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重
|
||||
# Collapse nodes to canonical reps before each round to avoid redundant comparisons
|
||||
# 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量)
|
||||
current_nodes = _collapse_nodes(current_nodes)
|
||||
# 步骤2:分块(按group_id分块,避免跨组处理)
|
||||
# 步骤2:分块(按end_user_id分块,避免跨组处理)
|
||||
blocks = _partition_blocks(current_nodes)
|
||||
if not blocks: # 无块可处理(实体已全部折叠),退出循环
|
||||
break
|
||||
@@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative(
|
||||
a = entity_nodes[i]
|
||||
b = entity_nodes[j]
|
||||
# 必须同组
|
||||
if getattr(a, "group_id", None) != getattr(b, "group_id", None):
|
||||
if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None):
|
||||
continue
|
||||
ta = getattr(a, "entity_type", None)
|
||||
tb = getattr(b, "entity_type", None)
|
||||
|
||||
@@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
return ExtractedEntityNode(
|
||||
id=row.get("id"),
|
||||
name=row.get("name") or "",
|
||||
group_id=row.get("group_id") or "",
|
||||
end_user_id=row.get("end_user_id") or "",
|
||||
user_id=row.get("user_id") or "",
|
||||
apply_id=row.get("apply_id") or "",
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
|
||||
async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重
|
||||
connector: Neo4jConnector,
|
||||
group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重
|
||||
entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体
|
||||
statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系
|
||||
entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系
|
||||
@@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]:
|
||||
"""
|
||||
第二层去重消歧:
|
||||
- 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体
|
||||
- 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体
|
||||
- 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合
|
||||
- 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID)
|
||||
"""
|
||||
@@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑
|
||||
|
||||
]
|
||||
candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。
|
||||
connector=connector, group_id=group_id,
|
||||
connector=connector, end_user_id=end_user_id,
|
||||
entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引)
|
||||
use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动
|
||||
)
|
||||
|
||||
@@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
if pipeline_config is None:
|
||||
raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return")
|
||||
|
||||
# 先探测 group_id,决定报告写入策略
|
||||
group_id: Optional[str] = None
|
||||
# 先探测 end_user_id,决定报告写入策略
|
||||
end_user_id: Optional[str] = None
|
||||
for dd in dialog_data_list:
|
||||
group_id = getattr(dd, "group_id", None)
|
||||
if group_id:
|
||||
end_user_id = getattr(dd, "end_user_id", None)
|
||||
if end_user_id:
|
||||
break
|
||||
|
||||
# 第一层去重消歧
|
||||
@@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return(
|
||||
|
||||
# 第二层去重消歧:与 Neo4j 中同组实体联合融合
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
if connector:
|
||||
fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
entity_nodes=dedup_entity_nodes,
|
||||
statement_entity_edges=dedup_statement_entity_edges,
|
||||
entity_entity_edges=dedup_entity_entity_edges,
|
||||
@@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return(
|
||||
else:
|
||||
print("Skip second-layer dedup: missing connector")
|
||||
else:
|
||||
print("Skip second-layer dedup: missing group_id")
|
||||
print("Skip second-layer dedup: missing end_user_id")
|
||||
except Exception as e:
|
||||
print(f"Second-layer dedup failed: {e}")
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ class ExtractionOrchestrator:
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
dialogue_content = dialog.content if self.config.statement_extraction.include_dialogue_context else None
|
||||
for c_idx, chunk in enumerate(dialog.chunks):
|
||||
all_chunks.append((chunk, dialog.group_id, dialogue_content))
|
||||
all_chunks.append((chunk, dialog.end_user_id, dialogue_content))
|
||||
chunk_metadata.append((d_idx, c_idx))
|
||||
|
||||
logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取")
|
||||
@@ -299,9 +299,9 @@ class ExtractionOrchestrator:
|
||||
# 全局并行处理所有分块
|
||||
async def extract_for_chunk(chunk_data, chunk_index):
|
||||
nonlocal completed_chunks
|
||||
chunk, group_id, dialogue_content = chunk_data
|
||||
chunk, end_user_id, dialogue_content = chunk_data
|
||||
try:
|
||||
statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content)
|
||||
statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content)
|
||||
|
||||
# 流式输出:每提取完一个分块的陈述句,立即发送进度
|
||||
# 注意:只在试运行模式下发送陈述句详情,正式模式不发送
|
||||
@@ -569,32 +569,32 @@ class ExtractionOrchestrator:
|
||||
if dialog_data_list and hasattr(dialog_data_list[0], 'config_id'):
|
||||
config_id = dialog_data_list[0].config_id
|
||||
|
||||
# 加载DataConfig
|
||||
data_config = None
|
||||
# 加载MemoryConfig
|
||||
memory_config = None
|
||||
if config_id:
|
||||
try:
|
||||
from app.db import SessionLocal
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
data_config = DataConfigRepository.get_by_id(db, config_id)
|
||||
memory_config = MemoryConfigRepository.get_by_id(db, config_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if data_config and not data_config.emotion_enabled:
|
||||
if memory_config and not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取已在配置中禁用,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"加载DataConfig失败: {e},将跳过情绪提取")
|
||||
logger.warning(f"加载MemoryConfig失败: {e},将跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
else:
|
||||
logger.info("未找到config_id,跳过情绪提取")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 如果配置未启用情绪提取,直接返回空映射
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
if not memory_config or not memory_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
@@ -608,7 +608,7 @@ class ExtractionOrchestrator:
|
||||
total_statements += 1
|
||||
# 只处理用户的陈述句 (role 为 "user")
|
||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||
all_statements.append((statement, data_config))
|
||||
all_statements.append((statement, memory_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
filtered_statements += 1
|
||||
|
||||
@@ -617,7 +617,7 @@ class ExtractionOrchestrator:
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
emotion_service = EmotionExtractionService(
|
||||
llm_id=data_config.emotion_model_id if data_config.emotion_model_id else None
|
||||
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
|
||||
)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
@@ -992,9 +992,7 @@ class ExtractionOrchestrator:
|
||||
id=dialog_data.id,
|
||||
name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段
|
||||
ref_id=dialog_data.ref_id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=dialog_data.context.content if dialog_data.context else "",
|
||||
dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None,
|
||||
@@ -1012,9 +1010,7 @@ class ExtractionOrchestrator:
|
||||
id=chunk.id,
|
||||
name=f"Chunk_{chunk.id}", # 添加必需的 name 字段
|
||||
dialog_id=dialog_data.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
content=chunk.content,
|
||||
chunk_embedding=chunk.chunk_embedding,
|
||||
@@ -1035,9 +1031,7 @@ class ExtractionOrchestrator:
|
||||
stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段
|
||||
temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段
|
||||
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
@@ -1060,9 +1054,7 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edge = StatementChunkEdge(
|
||||
source=statement.id,
|
||||
target=chunk.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1072,13 +1064,16 @@ class ExtractionOrchestrator:
|
||||
if statement.triplet_extraction_info:
|
||||
triplet_info = statement.triplet_extraction_info
|
||||
|
||||
# 创建实体索引到ID的映射
|
||||
# 创建实体索引到ID的映射(支持多种索引方式)
|
||||
entity_idx_to_id = {}
|
||||
|
||||
# 创建实体节点
|
||||
for entity_idx, entity in enumerate(triplet_info.entities):
|
||||
# 映射实体索引到实体ID
|
||||
# 映射实体索引到实体ID(使用多个键以提高容错性)
|
||||
# 1. 使用实体自己的 entity_idx
|
||||
entity_idx_to_id[entity.entity_idx] = entity.id
|
||||
# 2. 使用枚举索引(从0开始)
|
||||
entity_idx_to_id[entity_idx] = entity.id
|
||||
|
||||
if entity.id not in entity_id_set:
|
||||
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
|
||||
@@ -1095,9 +1090,7 @@ class ExtractionOrchestrator:
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1112,9 +1105,7 @@ class ExtractionOrchestrator:
|
||||
source=statement.id,
|
||||
target=entity.id,
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong',
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
)
|
||||
@@ -1134,9 +1125,7 @@ class ExtractionOrchestrator:
|
||||
relation_type=triplet.predicate,
|
||||
statement=statement.statement,
|
||||
source_statement_id=statement.id,
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
end_user_id=dialog_data.end_user_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
created_at=dialog_data.created_at,
|
||||
expired_at=dialog_data.expired_at,
|
||||
@@ -1163,9 +1152,18 @@ class ExtractionOrchestrator:
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
f"object_id={triplet.object_id}, statement_id={statement.id}"
|
||||
# 改进的警告信息,包含更多调试信息
|
||||
missing_subject = "subject" if not subject_entity_id else ""
|
||||
missing_object = "object" if not object_entity_id else ""
|
||||
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
|
||||
|
||||
logger.debug(
|
||||
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
|
||||
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
|
||||
f"object_id={triplet.object_id} ({triplet.object_name}), "
|
||||
f"predicate={triplet.predicate}, "
|
||||
f"statement_id={statement.id}, "
|
||||
f"available_indices={sorted(entity_idx_to_id.keys())}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -1763,14 +1761,14 @@ class ExtractionOrchestrator:
|
||||
|
||||
async def get_chunked_dialogs(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "group_1",
|
||||
end_user_id: str = "group_1",
|
||||
indices: Optional[List[int]] = None,
|
||||
) -> List[DialogData]:
|
||||
"""从测试数据生成分块对话
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略(默认: RecursiveChunker)
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
indices: 要处理的数据索引列表(可选)
|
||||
|
||||
Returns:
|
||||
@@ -1834,7 +1832,7 @@ async def get_chunked_dialogs(
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=data['id'],
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
metadata=dialog_metadata,
|
||||
)
|
||||
|
||||
@@ -1936,7 +1934,7 @@ async def get_chunked_dialogs_from_preprocessed(
|
||||
|
||||
async def get_chunked_dialogs_with_preprocessing(
|
||||
chunker_strategy: str = "RecursiveChunker",
|
||||
group_id: str = "default",
|
||||
end_user_id: str = "default",
|
||||
user_id: str = "default",
|
||||
apply_id: str = "default",
|
||||
indices: Optional[List[int]] = None,
|
||||
@@ -1948,7 +1946,7 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
|
||||
Args:
|
||||
chunker_strategy: 分块策略
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
user_id: 用户ID
|
||||
apply_id: 应用ID
|
||||
indices: 要处理的数据索引列表
|
||||
@@ -1976,11 +1974,9 @@ async def get_chunked_dialogs_with_preprocessing(
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
# 设置 group_id, user_id, apply_id
|
||||
# 设置 end_user_id
|
||||
for dd in preprocessed_data:
|
||||
dd.group_id = group_id
|
||||
dd.user_id = user_id
|
||||
dd.apply_id = apply_id
|
||||
dd.end_user_id = end_user_id
|
||||
|
||||
# 步骤2: 语义剪枝
|
||||
try:
|
||||
|
||||
@@ -193,9 +193,9 @@ async def _process_chunk_summary(
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
end_user_id=dialog.end_user_id,
|
||||
user_id=dialog.end_user_id,
|
||||
apply_id=dialog.end_user_id,
|
||||
run_id=dialog.run_id, # 使用 dialog 的 run_id
|
||||
created_at=datetime.now(),
|
||||
expired_at=datetime(9999, 12, 31),
|
||||
|
||||
@@ -82,12 +82,12 @@ class StatementExtractor:
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
Args:
|
||||
chunk: Chunk object to process
|
||||
group_id: Group ID to assign to all statements in this chunk
|
||||
end_user_id: Group ID to assign to all statements in this chunk
|
||||
dialogue_content: Full dialogue content to provide as context
|
||||
|
||||
Returns:
|
||||
@@ -158,7 +158,7 @@ class StatementExtractor:
|
||||
temporal_info=temporal_type,
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
speaker=chunk_speaker,
|
||||
)
|
||||
|
||||
@@ -184,10 +184,10 @@ class StatementExtractor:
|
||||
|
||||
logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction")
|
||||
|
||||
# Process all chunks concurrently, passing the group_id and dialogue content from dialog_data
|
||||
# Process all chunks concurrently, passing the end_user_id and dialogue content from dialog_data
|
||||
dialogue_content = dialog_data.content if self.config.include_dialogue_context else None
|
||||
results = await asyncio.gather(
|
||||
*[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process],
|
||||
*[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
@@ -225,7 +225,7 @@ class StatementExtractor:
|
||||
for i, statement in enumerate(statements, 1):
|
||||
f.write(f"Statement {i}:\n")
|
||||
f.write(f"Id: {statement.id}\n")
|
||||
f.write(f"Group Id: {statement.group_id}\n")
|
||||
f.write(f"Group Id: {statement.end_user_id}\n")
|
||||
f.write(f"Content: {statement.statement}\n")
|
||||
f.write(f"Type: {statement.stmt_type.value}\n")
|
||||
f.write(f"Temporal Info: {statement.temporal_info.value}\n")
|
||||
@@ -298,7 +298,7 @@ class StatementExtractor:
|
||||
|
||||
dialog_sections.append({
|
||||
"dialog_id": dialog.ref_id,
|
||||
"group_id": dialog.group_id,
|
||||
"end_user_id": dialog.end_user_id,
|
||||
"content": dialog.content if getattr(dialog, "content", None) else "",
|
||||
"strong": strong_relations,
|
||||
"weak": weak_relations,
|
||||
@@ -312,7 +312,7 @@ class StatementExtractor:
|
||||
for idx, section in enumerate(dialog_sections, 1):
|
||||
f.write(f"Dialog {idx}:\n")
|
||||
f.write(f"Dialog ID: {section.get('dialog_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('group_id', '')}\n")
|
||||
f.write(f"Group ID: {section.get('end_user_id', '')}\n")
|
||||
f.write("Content:\n")
|
||||
f.write(f"{section.get('content', '')}\n")
|
||||
f.write("-" * 40 + "\n\n")
|
||||
|
||||
@@ -132,7 +132,7 @@ class TemporalExtractor:
|
||||
prompt_logger.info("")
|
||||
prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===")
|
||||
prompt_logger.info(
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}"
|
||||
f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -116,7 +116,7 @@ class TripletExtractor:
|
||||
logger.info(f"Processing {len(all_statements)} statements for triplet extraction...")
|
||||
try:
|
||||
prompt_logger.info(
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}"
|
||||
f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -75,7 +75,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -91,7 +91,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选,用于过滤)
|
||||
end_user_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
|
||||
Returns:
|
||||
@@ -123,7 +123,7 @@ class AccessHistoryManager:
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
@@ -142,7 +142,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -172,7 +172,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_ids: List[str],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -184,7 +184,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
|
||||
Returns:
|
||||
@@ -202,7 +202,7 @@ class AccessHistoryManager:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
@@ -235,7 +235,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
@@ -249,14 +249,14 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
@@ -305,7 +305,7 @@ class AccessHistoryManager:
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -313,7 +313,7 @@ class AccessHistoryManager:
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
@@ -329,16 +329,16 @@ class AccessHistoryManager:
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
"""
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
if end_user_id:
|
||||
params["end_user_id"] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
@@ -351,7 +351,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -387,7 +387,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
@@ -401,7 +401,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
@@ -411,7 +411,7 @@ class AccessHistoryManager:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
@@ -419,7 +419,7 @@ class AccessHistoryManager:
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
node_data = await self._fetch_node(node_id, node_label, end_user_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
@@ -457,8 +457,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
SET n += $repair_data
|
||||
RETURN n
|
||||
@@ -468,8 +468,8 @@ class AccessHistoryManager:
|
||||
'node_id': node_id,
|
||||
'repair_data': repair_data
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -491,7 +491,7 @@ class AccessHistoryManager:
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
@@ -499,7 +499,7 @@ class AccessHistoryManager:
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
@@ -507,8 +507,8 @@ class AccessHistoryManager:
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " WHERE n.end_user_id = $end_user_id"
|
||||
query += """
|
||||
RETURN n.id as id,
|
||||
n.importance_score as importance_score,
|
||||
@@ -519,8 +519,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
params = {'node_id': node_id}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -585,7 +585,7 @@ class AccessHistoryManager:
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
update_data: Dict[str, Any],
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
@@ -597,7 +597,7 @@ class AccessHistoryManager:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
update_data: 更新数据
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
@@ -606,13 +606,13 @@ class AccessHistoryManager:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
||||
async def update_transaction(tx, node_id, node_label, update_data, end_user_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
read_query += " WHERE n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
read_query += " WHERE n.end_user_id = $end_user_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
@@ -624,8 +624,8 @@ class AccessHistoryManager:
|
||||
"""
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if group_id:
|
||||
read_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
read_params['end_user_id'] = end_user_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
@@ -656,8 +656,8 @@ class AccessHistoryManager:
|
||||
|
||||
# 构建 WHERE 子句
|
||||
where_conditions = []
|
||||
if group_id:
|
||||
where_conditions.append("n.group_id = $group_id")
|
||||
if end_user_id:
|
||||
where_conditions.append("n.end_user_id = $end_user_id")
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
@@ -695,8 +695,8 @@ class AccessHistoryManager:
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if group_id:
|
||||
update_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
update_params['end_user_id'] = end_user_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
@@ -720,7 +720,7 @@ class AccessHistoryManager:
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -11,9 +11,10 @@ Functions:
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from uuid import UUID
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
@@ -61,12 +62,12 @@ def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
||||
|
||||
def load_actr_config_from_db(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库加载 ACT-R 配置参数
|
||||
|
||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
||||
从 PostgreSQL 的 memory_config 表读取配置参数,
|
||||
并计算派生参数(如 forgetting_rate)。
|
||||
|
||||
Args:
|
||||
@@ -99,7 +100,7 @@ def load_actr_config_from_db(
|
||||
|
||||
# 从数据库加载配置
|
||||
try:
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None:
|
||||
@@ -150,7 +151,7 @@ def load_actr_config_from_db(
|
||||
|
||||
def create_actr_calculator_from_config(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> ACTRCalculator:
|
||||
"""
|
||||
从数据库配置创建 ACTRCalculator 实例
|
||||
@@ -168,11 +169,6 @@ def create_actr_calculator_from_config(
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
|
||||
Examples:
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> db = Session()
|
||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
||||
>>> # 使用计算器
|
||||
>>> activation = calculator.calculate_memory_activation(...)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
@@ -16,6 +16,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
@@ -66,10 +67,10 @@ class ForgettingScheduler:
|
||||
|
||||
async def run_forgetting_cycle(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_merge_batch_size: int = 100,
|
||||
min_days_since_access: int = 30,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -77,7 +78,7 @@ class ForgettingScheduler:
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
@@ -107,19 +108,19 @@ class ForgettingScheduler:
|
||||
start_time_iso = start_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
f"开始遗忘周期: group_id={group_id}, "
|
||||
f"开始遗忘周期: end_user_id={end_user_id}, "
|
||||
f"max_batch={max_merge_batch_size}, "
|
||||
f"min_days={min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤1:统计遗忘前的节点数量
|
||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
||||
nodes_before = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||
|
||||
# 步骤2:识别可遗忘的节点对
|
||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
min_days_since_access=min_days_since_access
|
||||
)
|
||||
|
||||
@@ -213,7 +214,7 @@ class ForgettingScheduler:
|
||||
'statement_text': pair['statement_text'],
|
||||
'statement_activation': pair['statement_activation'],
|
||||
'statement_importance': pair['statement_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
entity_node = {
|
||||
@@ -222,7 +223,7 @@ class ForgettingScheduler:
|
||||
'entity_type': pair['entity_type'],
|
||||
'entity_activation': pair['entity_activation'],
|
||||
'entity_importance': pair['entity_importance'],
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
# 融合节点
|
||||
@@ -262,7 +263,7 @@ class ForgettingScheduler:
|
||||
continue
|
||||
|
||||
# 步骤6:统计遗忘后的节点数量
|
||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
||||
nodes_after = await self._count_knowledge_nodes(end_user_id)
|
||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||
|
||||
# 步骤7:生成遗忘报告
|
||||
@@ -315,7 +316,7 @@ class ForgettingScheduler:
|
||||
|
||||
async def _count_knowledge_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None
|
||||
end_user_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计知识层节点总数
|
||||
@@ -323,7 +324,7 @@ class ForgettingScheduler:
|
||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
|
||||
Returns:
|
||||
int: 知识层节点总数
|
||||
@@ -333,16 +334,16 @@ class ForgettingScheduler:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
params = {}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -90,7 +91,7 @@ class ForgettingStrategy:
|
||||
|
||||
async def find_forgettable_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
min_days_since_access: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
@@ -102,7 +103,7 @@ class ForgettingStrategy:
|
||||
3. Statement 和 Entity 之间存在关系边
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
end_user_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
|
||||
Returns:
|
||||
@@ -136,8 +137,8 @@ class ForgettingStrategy:
|
||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN s.id as statement_id,
|
||||
@@ -159,8 +160,8 @@ class ForgettingStrategy:
|
||||
'threshold': self.forgetting_threshold,
|
||||
'cutoff_time': cutoff_time_iso
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
@@ -176,7 +177,7 @@ class ForgettingStrategy:
|
||||
self,
|
||||
statement_node: Dict[str, Any],
|
||||
entity_node: Dict[str, Any],
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -247,8 +248,8 @@ class ForgettingStrategy:
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
# 获取 end_user_id(从 statement 或 entity 节点)
|
||||
end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id')
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
@@ -325,7 +326,7 @@ class ForgettingStrategy:
|
||||
last_access_time: $current_time,
|
||||
access_count: 1,
|
||||
version: 1,
|
||||
group_id: $group_id,
|
||||
end_user_id: $end_user_id,
|
||||
created_at: datetime($current_time),
|
||||
merged_at: datetime($current_time)
|
||||
})
|
||||
@@ -423,7 +424,7 @@ class ForgettingStrategy:
|
||||
'inherited_activation': inherited_activation,
|
||||
'inherited_importance': inherited_importance,
|
||||
'current_time': current_time_iso,
|
||||
'group_id': group_id
|
||||
'end_user_id': end_user_id
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -462,7 +463,7 @@ class ForgettingStrategy:
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
@@ -527,7 +528,7 @@ class ForgettingStrategy:
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
async def _get_llm_client(self, db, config_id: int):
|
||||
async def _get_llm_client(self, db, config_id: UUID):
|
||||
"""
|
||||
从数据库获取 LLM 客户端
|
||||
|
||||
@@ -539,11 +540,11 @@ class ForgettingStrategy:
|
||||
LLM 客户端实例,如果无法获取则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 从数据库读取配置
|
||||
repository = DataConfigRepository()
|
||||
repository = MemoryConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None or db_config.llm_id is None:
|
||||
|
||||
@@ -37,7 +37,7 @@ __all__ = [
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
group_id: str | None = None,
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
@@ -54,7 +54,7 @@ async def run_hybrid_search(
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
group_id: 组ID过滤
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
@@ -104,7 +104,7 @@ async def run_hybrid_search(
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
|
||||
@@ -77,7 +77,7 @@
|
||||
# async def search(
|
||||
# self,
|
||||
# query_text: str,
|
||||
# group_id: Optional[str] = None,
|
||||
# end_user_id: Optional[str] = None,
|
||||
# limit: int = 50,
|
||||
# include: Optional[List[str]] = None,
|
||||
# **kwargs
|
||||
@@ -86,7 +86,7 @@
|
||||
|
||||
# Args:
|
||||
# query_text: 查询文本
|
||||
# group_id: 可选的组ID过滤
|
||||
# end_user_id: 可选的组ID过滤
|
||||
# limit: 每个类别的最大结果数
|
||||
# include: 要包含的搜索类别列表
|
||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
@@ -94,7 +94,7 @@
|
||||
# Returns:
|
||||
# SearchResult: 搜索结果对象
|
||||
# """
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# # 从kwargs中获取参数
|
||||
# alpha = kwargs.get("alpha", self.alpha)
|
||||
@@ -107,14 +107,14 @@
|
||||
# # 并行执行关键词搜索和语义搜索
|
||||
# keyword_result = await self.keyword_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# semantic_result = await self.semantic_strategy.search(
|
||||
# query_text=query_text,
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
@@ -139,7 +139,7 @@
|
||||
# metadata = self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list,
|
||||
# alpha=alpha,
|
||||
@@ -165,7 +165,7 @@
|
||||
# metadata=self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# group_id=group_id,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# error=str(e)
|
||||
# )
|
||||
|
||||
@@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
q=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ class SearchStrategy(ABC):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -67,7 +67,7 @@ class SearchStrategy(ABC):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -81,7 +81,7 @@ class SearchStrategy(ABC):
|
||||
self,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
@@ -90,7 +90,7 @@ class SearchStrategy(ABC):
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
limit: 结果限制
|
||||
**kwargs: 其他元数据
|
||||
|
||||
@@ -100,7 +100,7 @@ class SearchStrategy(ABC):
|
||||
metadata = {
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
@@ -94,7 +94,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
group_id: 可选的组ID过滤
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
@@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}")
|
||||
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
@@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder_client,
|
||||
query_text=query_text,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
@@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy):
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]:
|
||||
target_keys = [
|
||||
"id",
|
||||
"statement",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"chunk_id",
|
||||
"created_at",
|
||||
"expired_at",
|
||||
@@ -75,7 +75,7 @@ async def get_data(result):
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"end_user_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
|
||||
@@ -62,7 +62,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
config_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
success: bool = True,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
@@ -72,14 +72,14 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
user_id: 用户 ID(可选)
|
||||
group_id: 组 ID(可选)
|
||||
end_user_id: 组 ID(可选)
|
||||
success: 是否成功
|
||||
details: 详细信息(可选)
|
||||
"""
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"CONFIG_LOAD config_id={config_id} "
|
||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
||||
f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} "
|
||||
f"result={result}"
|
||||
)
|
||||
if details:
|
||||
@@ -121,7 +121,7 @@ class ConfigAuditLogger:
|
||||
self,
|
||||
operation: str,
|
||||
config_id: str,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
success: bool = True,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
@@ -133,7 +133,7 @@ class ConfigAuditLogger:
|
||||
Args:
|
||||
operation: 操作类型(WRITE, READ 等)
|
||||
config_id: 配置 ID
|
||||
group_id: 组 ID
|
||||
end_user_id: 组 ID
|
||||
success: 是否成功
|
||||
duration: 操作耗时(秒)
|
||||
error: 错误信息(可选)
|
||||
@@ -142,7 +142,7 @@ class ConfigAuditLogger:
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"{operation.upper()} config_id={config_id} "
|
||||
f"group={group_id} result={result}"
|
||||
f"group={end_user_id} result={result}"
|
||||
)
|
||||
if duration is not None:
|
||||
msg += f" duration={duration:.2f}s"
|
||||
|
||||
1
api/app/core/models/scripts/__init__.py
Normal file
1
api/app/core/models/scripts/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""模型配置脚本模块"""
|
||||
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
174
api/app/core/models/scripts/bedrock_models.yaml
Normal file
@@ -0,0 +1,174 @@
|
||||
provider: bedrock
|
||||
enabled: true
|
||||
models:
|
||||
- name: ai21
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: bedrock
|
||||
- name: amazon nova
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: bedrock
|
||||
- name: anthropic claude
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
logo: bedrock
|
||||
- name: cohere
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
- name: deepseek
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
- name: meta
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
- name: mistral
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
logo: bedrock
|
||||
- name: openai
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
- name: qwen
|
||||
type: llm
|
||||
provider: bedrock
|
||||
description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
logo: bedrock
|
||||
- name: amazon.rerank-v1:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
description: amazon.rerank-v1:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
- name: cohere.rerank-v3-5:0
|
||||
type: rerank
|
||||
provider: bedrock
|
||||
description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: bedrock
|
||||
- name: amazon.nova-2-multimodal-embeddings-v1:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
- vision
|
||||
logo: bedrock
|
||||
- name: amazon.titan-embed-text-v1
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
- name: amazon.titan-embed-text-v2:0
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
- name: cohere.embed-english-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
description: Cohere Embed 3 English文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
- name: cohere.embed-multilingual-v3
|
||||
type: embedding
|
||||
provider: bedrock
|
||||
description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本嵌入模型
|
||||
logo: bedrock
|
||||
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
820
api/app/core/models/scripts/dashscope_models.yaml
Normal file
@@ -0,0 +1,820 @@
|
||||
provider: dashscope
|
||||
enabled: true
|
||||
models:
|
||||
- name: deepseek-r1-distill-qwen-14b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-r1-distill-qwen-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-r1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-v3.1
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-v3.2-exp
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-v3.2
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: deepseek-v3
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: farui-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: glm-4.7
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qvq-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qvq-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-coder-turbo-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen-max-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-max-longcontext
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-mt-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen-mt-turbo
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 翻译模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen-plus-0112
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-0125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-0723
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-0806
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-0919
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-1125
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-1127
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-plus-1220
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-vl-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-0809
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-02
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-25
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-latest
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen2.5-0.5b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-14b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-235b-a22b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-235b-a22b-thinking-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-235b-a22b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-30b-a3b-instruct-2507
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-30b-a3b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-4b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-8b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-coder-30b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen3-coder-480b-a35b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen3-coder-plus-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen3-coder-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 代码模型
|
||||
- agent-thought
|
||||
logo: dashscope
|
||||
- name: qwen3-max-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
- name: qwen3-max-2026-01-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
- name: qwen3-max-preview
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-max
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- 联网搜索
|
||||
logo: dashscope
|
||||
- name: qwen3-next-80b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-next-80b-a3b-thinking
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen3-omni-flash-2025-12-01
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
- audio
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-thinking
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-instruct
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-thinking
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-flash
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus-2025-09-23
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
- agent-thought
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwq-32b
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwq-plus-0305
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwq-plus
|
||||
type: llm
|
||||
provider: dashscope
|
||||
description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: gte-rerank-v2
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
description: gte-rerank-v2重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
- name: gte-rerank
|
||||
type: rerank
|
||||
provider: dashscope
|
||||
description: gte-rerank重排序模型,4000上下文窗口
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 重排序模型
|
||||
logo: dashscope
|
||||
- name: multimodal-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 多模态模型
|
||||
- vision
|
||||
logo: dashscope
|
||||
- name: text-embedding-v1
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
- name: text-embedding-v2
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
- name: text-embedding-v3
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
- name: text-embedding-v4
|
||||
type: embedding
|
||||
provider: dashscope
|
||||
description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 嵌入模型
|
||||
- 文本嵌入
|
||||
logo: dashscope
|
||||
143
api/app/core/models/scripts/loader.py
Normal file
143
api/app/core/models/scripts/loader.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
|
||||
|
||||
def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
"""从YAML文件加载指定供应商的模型配置"""
|
||||
config_dir = Path(__file__).parent
|
||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||
|
||||
if not config_file.exists():
|
||||
return []
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
# 检查是否需要加载(默认为 true)
|
||||
if not data.get('enabled', True):
|
||||
return []
|
||||
|
||||
return data.get('models', [])
|
||||
|
||||
|
||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
||||
"""将YAML文件的enabled标志设置为false"""
|
||||
config_dir = Path(__file__).parent
|
||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||
|
||||
if not config_file.exists():
|
||||
return
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
data['enabled'] = False
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||
"""
|
||||
加载模型配置到数据库
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
providers: 要加载的供应商列表,None表示加载所有
|
||||
silent: 是否静默模式(不输出详细日志)
|
||||
|
||||
Returns:
|
||||
dict: 加载结果统计 {"success": int, "skipped": int, "failed": int}
|
||||
"""
|
||||
result = {"success": 0, "skipped": 0, "failed": 0}
|
||||
|
||||
# 确定要加载的供应商
|
||||
if providers:
|
||||
target_providers = [ModelProvider(p) if isinstance(p, str) else p for p in providers]
|
||||
else:
|
||||
target_providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
|
||||
for provider in target_providers:
|
||||
# 从YAML文件加载模型配置
|
||||
models = _load_yaml_config(provider)
|
||||
|
||||
if not models:
|
||||
if not silent:
|
||||
print(f"警告: 供应商 '{provider.value}' 暂无预定义模型")
|
||||
continue
|
||||
|
||||
if not silent:
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
# provider_success = 0
|
||||
for model_data in models:
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
existing = db.query(ModelBase).filter(
|
||||
ModelBase.name == model_data["name"],
|
||||
ModelBase.provider == model_data["provider"]
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# 更新现有模型配置
|
||||
for key, value in model_data.items():
|
||||
setattr(existing, key, value)
|
||||
db.commit()
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
else:
|
||||
# 创建新模型
|
||||
model = ModelBase(**model_data)
|
||||
db.add(model)
|
||||
db.commit()
|
||||
if not silent:
|
||||
print(f"添加成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
if not silent:
|
||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||
result["failed"] += 1
|
||||
|
||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
||||
# if provider_success == len(models):
|
||||
_disable_yaml_config(provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_models_by_provider(db: Session, provider: str) -> dict:
|
||||
"""
|
||||
加载指定供应商的模型配置
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
provider: 供应商名称(字符串或ModelProvider枚举)
|
||||
|
||||
Returns:
|
||||
dict: 加载结果统计
|
||||
"""
|
||||
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
||||
return load_models(db, providers=[provider_enum])
|
||||
|
||||
|
||||
def get_available_providers() -> list[Callable[[], str]]:
|
||||
"""获取所有可用的供应商列表(从ModelProvider枚举获取,排除COMPOSITE)"""
|
||||
return [p.value for p in ModelProvider if p != ModelProvider.COMPOSITE]
|
||||
|
||||
|
||||
def get_models_by_provider(provider: str) -> list[dict]:
|
||||
"""获取指定供应商的模型配置列表"""
|
||||
provider_enum = ModelProvider(provider) if isinstance(provider, str) else provider
|
||||
return _load_yaml_config(provider_enum)
|
||||
294
api/app/core/models/scripts/openai_models.yaml
Normal file
294
api/app/core/models/scripts/openai_models.yaml
Normal file
@@ -0,0 +1,294 @@
|
||||
provider: openai
|
||||
enabled: true
|
||||
models:
|
||||
- name: chatgpt-4o-latest
|
||||
type: llm
|
||||
provider: openai
|
||||
description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-0125
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-1106
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-16k
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo-instruct
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
logo: openai
|
||||
- name: gpt-3.5-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-4-0125-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-4-1106-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-4-turbo-2024-04-09
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
- name: gpt-4-turbo-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
logo: openai
|
||||
- name: gpt-4-turbo
|
||||
type: llm
|
||||
provider: openai
|
||||
description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
logo: openai
|
||||
- name: o1-preview
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
logo: openai
|
||||
- name: o1
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- multi-tool-call
|
||||
- agent-thought
|
||||
- stream-tool-call
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- vision
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3-mini-2025-01-31
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3-pro-2025-06-10
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3-pro
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- vision
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o3
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o4-mini-2025-04-16
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- vision
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: o4-mini
|
||||
type: llm
|
||||
provider: openai
|
||||
description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 大语言模型
|
||||
- agent-thought
|
||||
- tool-call
|
||||
- vision
|
||||
- stream-tool-call
|
||||
- structured-output
|
||||
logo: openai
|
||||
- name: text-embedding-3-large
|
||||
type: embedding
|
||||
provider: openai
|
||||
description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
- name: text-embedding-3-small
|
||||
type: embedding
|
||||
provider: openai
|
||||
description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
- name: text-embedding-ada-002
|
||||
type: embedding
|
||||
provider: openai
|
||||
description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32
|
||||
is_deprecated: false
|
||||
is_official: true
|
||||
tags:
|
||||
- 文本向量模型
|
||||
logo: openai
|
||||
@@ -1,165 +0,0 @@
|
||||
import copy
|
||||
import re
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from app.core.rag.nlp import tokenize, is_english
|
||||
from app.core.rag.nlp import rag_tokenizer
|
||||
from app.core.rag.deepdoc.parser import PdfParser, PlainParser
|
||||
from app.core.rag.deepdoc.parser.ppt_parser import RAGPptParser as PptParser
|
||||
from PyPDF2 import PdfReader as pdf2_read
|
||||
from app.core.rag.app.naive import by_plaintext, PARSERS
|
||||
|
||||
class Ppt(PptParser):
|
||||
def __call__(self, fnm, from_page, to_page, callback=None):
|
||||
txts = super().__call__(fnm, from_page, to_page)
|
||||
|
||||
callback(0.5, "Text extraction finished.")
|
||||
import aspose.slides as slides
|
||||
import aspose.pydrawing as drawing
|
||||
imgs = []
|
||||
with slides.Presentation(BytesIO(fnm)) as presentation:
|
||||
for i, slide in enumerate(presentation.slides[from_page: to_page]):
|
||||
try:
|
||||
with BytesIO() as buffered:
|
||||
slide.get_thumbnail(
|
||||
0.1, 0.1).save(
|
||||
buffered, drawing.imaging.ImageFormat.jpeg)
|
||||
buffered.seek(0)
|
||||
imgs.append(Image.open(buffered).copy())
|
||||
except RuntimeError as e:
|
||||
raise RuntimeError(f'ppt parse error at page {i+1}, original error: {str(e)}') from e
|
||||
assert len(imgs) == len(
|
||||
txts), "Slides text and image do not match: {} vs. {}".format(len(imgs), len(txts))
|
||||
callback(0.9, "Image extraction finished")
|
||||
self.is_english = is_english(txts)
|
||||
return [(txts[i], imgs[i]) for i in range(len(txts))]
|
||||
|
||||
class Pdf(PdfParser):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def __garbage(self, txt):
|
||||
txt = txt.lower().strip()
|
||||
if re.match(r"[0-9\.,%/-]+$", txt):
|
||||
return True
|
||||
if len(txt) < 3:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, zoomin=3, callback=None):
|
||||
from timeit import default_timer as timer
|
||||
start = timer()
|
||||
callback(msg="OCR started")
|
||||
self.__images__(filename if not binary else binary,
|
||||
zoomin, from_page, to_page, callback)
|
||||
callback(msg="Page {}~{}: OCR finished ({:.2f}s)".format(from_page, min(to_page, self.total_page), timer() - start))
|
||||
assert len(self.boxes) == len(self.page_images), "{} vs. {}".format(
|
||||
len(self.boxes), len(self.page_images))
|
||||
res = []
|
||||
for i in range(len(self.boxes)):
|
||||
lines = "\n".join([b["text"] for b in self.boxes[i]
|
||||
if not self.__garbage(b["text"])])
|
||||
res.append((lines, self.page_images[i]))
|
||||
callback(0.9, "Page {}~{}: Parsing finished".format(
|
||||
from_page, min(to_page, self.total_page)))
|
||||
return res, []
|
||||
|
||||
|
||||
class PlainPdf(PlainParser):
|
||||
def __call__(self, filename, binary=None, from_page=0,
|
||||
to_page=100000, callback=None, **kwargs):
|
||||
self.pdf = pdf2_read(filename if not binary else BytesIO(binary))
|
||||
page_txt = []
|
||||
for page in self.pdf.pages[from_page: to_page]:
|
||||
page_txt.append(page.extract_text())
|
||||
callback(0.9, "Parsing finished")
|
||||
return [(txt, None) for txt in page_txt], []
|
||||
|
||||
|
||||
def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
lang="Chinese", callback=None, vision_model=None, parser_config=None, **kwargs):
|
||||
"""
|
||||
The supported file formats are pdf, pptx.
|
||||
Every page will be treated as a chunk. And the thumbnail of every page will be stored.
|
||||
PPT file will be parsed by using this method automatically, setting-up for every PPT file is not necessary.
|
||||
"""
|
||||
if parser_config is None:
|
||||
parser_config = {}
|
||||
eng = lang.lower() == "english"
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
}
|
||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||
res = []
|
||||
if re.search(r"\.pptx?$", filename, re.IGNORECASE):
|
||||
if not binary:
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
ppt_parser = Ppt()
|
||||
for pn, (txt, img) in enumerate(ppt_parser(
|
||||
filename if not binary else binary, from_page, 1000000, callback)):
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0], 0, img.size[1])]
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
|
||||
|
||||
if isinstance(layout_recognizer, bool):
|
||||
layout_recognizer = "DeepDOC" if layout_recognizer else "Plain Text"
|
||||
|
||||
name = layout_recognizer.strip().lower()
|
||||
parser = PARSERS.get(name, by_plaintext)
|
||||
callback(0.1, "Start to parse.")
|
||||
|
||||
sections, _, _ = parser(
|
||||
filename=filename,
|
||||
binary=binary,
|
||||
from_page=from_page,
|
||||
to_page=to_page,
|
||||
lang=lang,
|
||||
callback=callback,
|
||||
vision_model=vision_model,
|
||||
pdf_cls=Pdf,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if not sections:
|
||||
return []
|
||||
|
||||
if name in ["tcadp", "docling", "mineru"]:
|
||||
parser_config["chunk_token_num"] = 0
|
||||
|
||||
callback(0.8, "Finish parsing.")
|
||||
|
||||
for pn, (txt, img) in enumerate(sections):
|
||||
d = copy.deepcopy(doc)
|
||||
pn += from_page
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["page_num_int"] = [pn + 1]
|
||||
d["top_int"] = [0]
|
||||
d["position_int"] = [(pn + 1, 0, img.size[0] if img else 0, 0, img.size[1] if img else 0)]
|
||||
tokenize(d, txt, eng)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
raise NotImplementedError(
|
||||
"file type not supported yet(pptx, pdf supported)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def dummy(a, b):
|
||||
pass
|
||||
chunk(sys.argv[1], callback=dummy)
|
||||
@@ -4,7 +4,7 @@ from enum import StrEnum, auto
|
||||
class Field(StrEnum):
|
||||
CONTENT_KEY = "page_content"
|
||||
METADATA_KEY = "metadata"
|
||||
GROUP_KEY = "group_id"
|
||||
GROUP_KEY = "end_user_id"
|
||||
VECTOR = auto()
|
||||
# Sparse Vector aims to support full text search
|
||||
SPARSE_VECTOR = auto()
|
||||
|
||||
@@ -36,7 +36,7 @@ def generate_signed_url(
|
||||
"""
|
||||
if base_url is None:
|
||||
# Use SERVER_IP or default to localhost
|
||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
base_url = server_url
|
||||
|
||||
# Calculate expiration timestamp
|
||||
|
||||
@@ -16,7 +16,7 @@ class BaiduSearchTool(BuiltinTool):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"
|
||||
return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、视频搜索"
|
||||
|
||||
def get_required_config_parameters(self) -> List[str]:
|
||||
return ["api_key"]
|
||||
@@ -33,7 +33,7 @@ class BaiduSearchTool(BuiltinTool):
|
||||
ToolParameter(
|
||||
name="search_type",
|
||||
type=ParameterType.STRING,
|
||||
description="搜索类型",
|
||||
description="搜索类型, web: 网页搜索;news:新闻搜索;image:图片搜索;video视频搜索",
|
||||
required=False,
|
||||
default="web",
|
||||
enum=["web", "news", "image", "video"]
|
||||
|
||||
@@ -26,7 +26,7 @@ logger = get_config_logger()
|
||||
|
||||
|
||||
def _parse_model_id(model_id: Union[str, UUID, None], model_type: str,
|
||||
config_id: Optional[int] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||
config_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None) -> Optional[UUID]:
|
||||
"""Parse model ID from string or UUID."""
|
||||
if model_id is None:
|
||||
return None
|
||||
@@ -59,7 +59,7 @@ def validate_model_exists_and_active(
|
||||
model_type: str,
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[str, bool]:
|
||||
"""Validate that a model exists and is active.
|
||||
@@ -166,7 +166,7 @@ def validate_and_resolve_model_id(
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
required: bool = False,
|
||||
config_id: Optional[int] = None,
|
||||
config_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None
|
||||
) -> tuple[Optional[UUID], Optional[str]]:
|
||||
"""Validate and resolve a model ID, checking existence and active status.
|
||||
@@ -204,7 +204,7 @@ def validate_and_resolve_model_id(
|
||||
|
||||
|
||||
def validate_embedding_model(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
embedding_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
@@ -256,7 +256,7 @@ def validate_embedding_model(
|
||||
|
||||
|
||||
def validate_llm_model(
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
llm_id: Union[str, UUID, None],
|
||||
db: Session,
|
||||
tenant_id: Optional[UUID] = None,
|
||||
|
||||
@@ -11,17 +11,12 @@ from typing import Any
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.graph_builder import GraphBuilder
|
||||
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||
from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_config import VariableType
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
# from app.core.tools.registry import ToolRegistry
|
||||
# from app.core.tools.executor import ToolExecutor
|
||||
# from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
# TOOL_MANAGEMENT_AVAILABLE = True
|
||||
# from app.db import get_db
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -55,6 +50,8 @@ class WorkflowExecutor:
|
||||
self.execution_config = workflow_config.get("execution_config", {})
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
|
||||
self.checkpoint_config = RunnableConfig(
|
||||
configurable={
|
||||
@@ -127,7 +124,6 @@ class WorkflowExecutor:
|
||||
"user_id": self.user_id,
|
||||
"error": None,
|
||||
"error_node": None,
|
||||
"streaming_buffer": {}, # 流式缓冲区
|
||||
"cycle_nodes": [
|
||||
node.get("id")
|
||||
for node in self.workflow_config.get("nodes")
|
||||
@@ -139,9 +135,8 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
def _build_final_output(self, result, elapsed_time):
|
||||
def _build_final_output(self, result, elapsed_time, final_output):
|
||||
node_outputs = result.get("node_outputs", {})
|
||||
final_output = self._extract_final_output(node_outputs)
|
||||
token_usage = self._aggregate_token_usage(node_outputs)
|
||||
conversation_id = None
|
||||
for node_id, node_output in node_outputs.items():
|
||||
@@ -161,6 +156,146 @@ class WorkflowExecutor:
|
||||
"error": result.get("error"),
|
||||
}
|
||||
|
||||
def _update_scope_activate(self, scope, status=None):
|
||||
"""
|
||||
Update the activation state of all End nodes based on a completed scope (node or variable).
|
||||
|
||||
Iterates over all End nodes in `self.end_outputs` and calls
|
||||
`update_activate` on each, which may:
|
||||
- Activate variable segments that depend on the completed node/scope.
|
||||
- Activate the entire End node output if all control conditions are met.
|
||||
|
||||
If any End node becomes active and `self.activate_end` is not yet set,
|
||||
this node will be marked as the currently active End node.
|
||||
|
||||
Args:
|
||||
scope (str): The node ID or scope that has completed execution.
|
||||
status (str | None): Optional status of the node (used for branch/control nodes).
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
|
||||
def _update_stream_output_status(self, activate, data):
|
||||
"""
|
||||
Update the stream output state of End nodes based on workflow state updates.
|
||||
|
||||
This method checks which nodes/scopes are activated and propagates
|
||||
activation to End nodes accordingly.
|
||||
|
||||
Args:
|
||||
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
|
||||
data (dict): Mapping of node_id -> node runtime data, including outputs.
|
||||
|
||||
Behavior:
|
||||
For each node in `data`:
|
||||
1. If the node is activated (`activate[node_id]` is True),
|
||||
retrieve its output status from `runtime_vars`.
|
||||
2. Call `_update_scope_activate` to propagate the activation
|
||||
to all relevant End nodes and update `self.activate_end`.
|
||||
"""
|
||||
for node_id in data.keys():
|
||||
if activate.get(node_id):
|
||||
node_output_status = (
|
||||
data[node_id]
|
||||
.get('runtime_vars', {})
|
||||
.get(node_id)
|
||||
.get("output")
|
||||
)
|
||||
self._update_scope_activate(node_id, status=node_output_status)
|
||||
|
||||
async def _emit_active_chunks(
|
||||
self,
|
||||
node_outputs: dict,
|
||||
variables: dict,
|
||||
force=False
|
||||
):
|
||||
"""
|
||||
Process and yield all currently active output segments for the currently active End node.
|
||||
|
||||
This method handles stream-mode output for an End node by iterating through its output segments
|
||||
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
|
||||
`force=True`, which allows all segments to be processed regardless of their activation state.
|
||||
|
||||
Behavior:
|
||||
1. Iterates from the current `cursor` position to the end of the outputs list.
|
||||
2. For each segment:
|
||||
- If the segment is literal text (`is_variable=False`), append it directly.
|
||||
- If the segment is a variable (`is_variable=True`), evaluate it using
|
||||
`evaluate_expression` with the given `node_outputs` and `variables`,
|
||||
then transform the result with `_trans_output_string`.
|
||||
3. Yield a stream event of type "message" containing the processed chunk.
|
||||
4. Move the `cursor` forward after processing each segment.
|
||||
5. When all segments have been processed, remove this End node from `end_outputs`
|
||||
and reset `activate_end` to None.
|
||||
|
||||
Args:
|
||||
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
|
||||
variables (dict): Current runtime variables, used for variable evaluation.
|
||||
force (bool, default=False): If True, process segments even if `activate=False`.
|
||||
|
||||
Yields:
|
||||
dict: A stream event of type "message" containing the processed chunk.
|
||||
|
||||
Notes:
|
||||
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
|
||||
- This method only processes the currently active End node (`self.activate_end`).
|
||||
- Use `force=True` for final emission regardless of activation state.
|
||||
"""
|
||||
|
||||
end_info = self.end_outputs[self.activate_end]
|
||||
|
||||
while end_info.cursor < len(end_info.outputs):
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
# Literal segment
|
||||
if not current_segment.is_variable:
|
||||
final_chunk += current_segment.literal
|
||||
else:
|
||||
# Variable segment: evaluate and transform
|
||||
try:
|
||||
chunk = evaluate_expression(
|
||||
current_segment.literal,
|
||||
variables=variables,
|
||||
node_outputs=node_outputs
|
||||
)
|
||||
chunk = self._trans_output_string(chunk)
|
||||
final_chunk += chunk
|
||||
except ValueError:
|
||||
# Log failed evaluation but continue streaming
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
|
||||
|
||||
if final_chunk:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
# Advance cursor after processing
|
||||
end_info.cursor += 1
|
||||
|
||||
# Remove End node from active tracking if all segments have been processed
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
|
||||
@staticmethod
|
||||
def _trans_output_string(content):
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
return "\n".join(content)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
def build_graph(self, stream=False) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -173,6 +308,7 @@ class WorkflowExecutor:
|
||||
stream=stream,
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.end_outputs = builder.end_node_map
|
||||
graph = builder.build()
|
||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||
|
||||
@@ -205,14 +341,28 @@ class WorkflowExecutor:
|
||||
try:
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
|
||||
|
||||
full_content = ''
|
||||
for end_id in self.end_outputs.keys():
|
||||
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(f"工作流执行完成: execution_id={self.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
|
||||
return self._build_final_output(result, elapsed_time)
|
||||
return self._build_final_output(result, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
# 计算耗时(即使失败也记录)
|
||||
@@ -261,7 +411,7 @@ class WorkflowExecutor:
|
||||
"data": {
|
||||
"execution_id": self.execution_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"timestamp": start_time.isoformat()
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +423,8 @@ class WorkflowExecutor:
|
||||
# 3. Execute workflow
|
||||
try:
|
||||
chunk_count = 0
|
||||
|
||||
full_content = ''
|
||||
self._update_scope_activate("sys")
|
||||
async for event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
|
||||
@@ -293,20 +444,42 @@ class WorkflowExecutor:
|
||||
# Handle custom streaming events (chunks from nodes via stream writer)
|
||||
chunk_count += 1
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
"data": {
|
||||
"node_id": data.get("node_id"),
|
||||
"chunk": data.get("chunk"),
|
||||
"full_content": data.get("full_content"),
|
||||
"chunk_index": data.get("chunk_index"),
|
||||
"is_prefix": data.get("is_prefix"),
|
||||
"is_suffix": data.get("is_suffix"),
|
||||
"conversation_id": input_data.get("conversation_id"),
|
||||
if event_type == "node_chunk":
|
||||
node_id = data.get("node_id")
|
||||
if self.activate_end:
|
||||
end_info = self.end_outputs.get(self.activate_end)
|
||||
if not end_info or end_info.cursor >= len(end_info.outputs):
|
||||
continue
|
||||
current_output = end_info.outputs[end_info.cursor]
|
||||
if current_output.is_variable and current_output.depends_on_scope(node_id):
|
||||
if data.get("done"):
|
||||
end_info.cursor += 1
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
else:
|
||||
full_content += data.get("chunk")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
}
|
||||
}
|
||||
logger.info(f"[CUSTOM] ✅ 收到 {event_type} #{chunk_count} from {data.get('node_id')}"
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
elif event_type == "node_error":
|
||||
yield {
|
||||
"event": event_type, # "message" or "node_chunk"
|
||||
"data": {
|
||||
"node_id": data.get("node_id"),
|
||||
"status": "failed",
|
||||
"input": data.get("input_data"),
|
||||
"elapsed_time": data.get("elapsed_time"),
|
||||
"output": None,
|
||||
"error": data.get("error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
elif mode == "debug":
|
||||
# Handle debug information (node execution status)
|
||||
@@ -325,14 +498,15 @@ class WorkflowExecutor:
|
||||
conversation_id = input_data.get("conversation_id")
|
||||
logger.info(f"[NODE-START] Node starts execution: {node_name} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
yield {
|
||||
"event": "node_start",
|
||||
"data": {
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
}
|
||||
}
|
||||
elif event_type == "task_result":
|
||||
@@ -351,21 +525,82 @@ class WorkflowExecutor:
|
||||
"node_id": node_name,
|
||||
"conversation_id": conversation_id,
|
||||
"execution_id": self.execution_id,
|
||||
"timestamp": data.get("timestamp"),
|
||||
"state": result.get("node_outputs", {}).get(node_name),
|
||||
"timestamp": int(datetime.datetime.fromisoformat(
|
||||
data.get("timestamp")
|
||||
).timestamp() * 1000),
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input"),
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output"),
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
}
|
||||
}
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
# TODO:流式输出点
|
||||
state = graph.get_state(config=self.checkpoint_config).values
|
||||
node_outputs = state.get("runtime_vars", {})
|
||||
variables = state.get("variables", {})
|
||||
activate = state.get("activate", {})
|
||||
for _, node_data in data.items():
|
||||
node_outputs |= node_data.get("runtime_vars", {})
|
||||
variables |= node_data.get("variables", {})
|
||||
|
||||
self._update_stream_output_status(activate, data)
|
||||
wait = False
|
||||
while self.activate_end and not wait:
|
||||
async for msg_event in self._emit_active_chunks(
|
||||
node_outputs=node_outputs,
|
||||
variables=variables
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
if self.activate_end:
|
||||
wait = True
|
||||
else:
|
||||
self._update_stream_output_status(activate, data)
|
||||
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
node_outputs = result.get("runtime_vars", {})
|
||||
variables = result.get("variables", {})
|
||||
self.end_outputs = {
|
||||
node_id: node_info
|
||||
for node_id, node_info in self.end_outputs.items()
|
||||
if node_info.activate
|
||||
}
|
||||
|
||||
if self.end_outputs or self.activate_end:
|
||||
while self.activate_end:
|
||||
async for msg_event in self._emit_active_chunks(
|
||||
node_outputs=node_outputs,
|
||||
variables=variables,
|
||||
force=True
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
yield msg_event
|
||||
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
|
||||
# 计算耗时
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
result = graph.get_state(self.checkpoint_config).values
|
||||
logger.info(result)
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"total chunks: {chunk_count}, elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_id}"
|
||||
@@ -374,7 +609,7 @@ class WorkflowExecutor:
|
||||
# 发送 workflow_end 事件
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self._build_final_output(result, elapsed_time)
|
||||
"data": self._build_final_output(result, elapsed_time, full_content)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -396,31 +631,6 @@ class WorkflowExecutor:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_final_output(node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
优先级:
|
||||
1. 最后一个执行的非 start/end 节点的 output
|
||||
2. 如果没有节点输出,返回 None
|
||||
|
||||
Args:
|
||||
node_outputs: 所有节点的输出
|
||||
|
||||
Returns:
|
||||
最终输出字符串或 None
|
||||
"""
|
||||
if not node_outputs:
|
||||
return None
|
||||
|
||||
# 获取最后一个节点的输出
|
||||
last_node_output = list(node_outputs.values())[-1] if node_outputs else None
|
||||
|
||||
if last_node_output and isinstance(last_node_output, dict):
|
||||
return last_node_output.get("output")
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _aggregate_token_usage(node_outputs: dict[str, Any]) -> dict[str, int] | None:
|
||||
"""聚合所有节点的 token 使用情况
|
||||
@@ -511,178 +721,3 @@ async def execute_workflow_stream(
|
||||
)
|
||||
async for event in executor.execute_stream(input_data):
|
||||
yield event
|
||||
|
||||
# ==================== 工具管理系统集成 ====================
|
||||
|
||||
# def get_workflow_tools(workspace_id: str, user_id: str) -> list:
|
||||
# """获取工作流可用的工具列表
|
||||
#
|
||||
# Args:
|
||||
# workspace_id: 工作空间ID
|
||||
# user_id: 用户ID
|
||||
#
|
||||
# Returns:
|
||||
# 可用工具列表
|
||||
# """
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.warning("工具管理系统不可用")
|
||||
# return []
|
||||
#
|
||||
# try:
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具注册表
|
||||
# registry = ToolRegistry(db)
|
||||
#
|
||||
# # 注册内置工具类
|
||||
# from app.core.tools.builtin import (
|
||||
# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool
|
||||
# )
|
||||
# registry.register_tool_class(DateTimeTool)
|
||||
# registry.register_tool_class(JsonTool)
|
||||
# registry.register_tool_class(BaiduSearchTool)
|
||||
# registry.register_tool_class(MinerUTool)
|
||||
# registry.register_tool_class(TextInTool)
|
||||
#
|
||||
# # 获取活跃的工具
|
||||
# import uuid
|
||||
# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id))
|
||||
# active_tools = [tool for tool in tools if tool.status.value == "active"]
|
||||
#
|
||||
# # 转换为Langchain工具
|
||||
# langchain_tools = []
|
||||
# for tool_info in active_tools:
|
||||
# try:
|
||||
# tool_instance = registry.get_tool(tool_info.id)
|
||||
# if tool_instance:
|
||||
# langchain_tool = LangchainAdapter.convert_tool(tool_instance)
|
||||
# langchain_tools.append(langchain_tool)
|
||||
# except Exception as e:
|
||||
# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}")
|
||||
#
|
||||
# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具")
|
||||
# return langchain_tools
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流工具失败: {e}")
|
||||
# return []
|
||||
#
|
||||
#
|
||||
# class ToolWorkflowNode:
|
||||
# """工具工作流节点 - 在工作流中执行工具"""
|
||||
#
|
||||
# def __init__(self, node_config: dict, workflow_config: dict):
|
||||
# """初始化工具节点
|
||||
#
|
||||
# Args:
|
||||
# node_config: 节点配置
|
||||
# workflow_config: 工作流配置
|
||||
# """
|
||||
# self.node_config = node_config
|
||||
# self.workflow_config = workflow_config
|
||||
# self.tool_id = node_config.get("tool_id")
|
||||
# self.tool_parameters = node_config.get("parameters", {})
|
||||
#
|
||||
# async def run(self, state: WorkflowState) -> WorkflowState:
|
||||
# """执行工具节点"""
|
||||
# if not TOOL_MANAGEMENT_AVAILABLE:
|
||||
# logger.error("工具管理系统不可用")
|
||||
# state["error"] = "工具管理系统不可用"
|
||||
# return state
|
||||
#
|
||||
# try:
|
||||
# from sqlalchemy.orm import Session
|
||||
# db = next(get_db())
|
||||
#
|
||||
# # 创建工具执行器
|
||||
# registry = ToolRegistry(db)
|
||||
# executor = ToolExecutor(db, registry)
|
||||
#
|
||||
# # 准备参数(支持变量替换)
|
||||
# parameters = self._prepare_parameters(state)
|
||||
#
|
||||
# # 执行工具
|
||||
# result = await executor.execute_tool(
|
||||
# tool_id=self.tool_id,
|
||||
# parameters=parameters,
|
||||
# user_id=uuid.UUID(state["user_id"]),
|
||||
# workspace_id=uuid.UUID(state["workspace_id"])
|
||||
# )
|
||||
#
|
||||
# # 更新状态
|
||||
# node_id = self.node_config.get("id")
|
||||
# if result.success:
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "output": result.data,
|
||||
# "execution_time": result.execution_time,
|
||||
# "token_usage": result.token_usage
|
||||
# }
|
||||
#
|
||||
# # 更新运行时变量
|
||||
# if isinstance(result.data, dict):
|
||||
# for key, value in result.data.items():
|
||||
# state["runtime_vars"][f"{node_id}.{key}"] = value
|
||||
# else:
|
||||
# state["runtime_vars"][f"{node_id}.result"] = result.data
|
||||
# else:
|
||||
# state["error"] = result.error
|
||||
# state["error_node"] = node_id
|
||||
# state["node_outputs"][node_id] = {
|
||||
# "type": "tool",
|
||||
# "tool_id": self.tool_id,
|
||||
# "error": result.error,
|
||||
# "execution_time": result.execution_time
|
||||
# }
|
||||
#
|
||||
# return state
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"工具节点执行失败: {e}")
|
||||
# state["error"] = str(e)
|
||||
# state["error_node"] = self.node_config.get("id")
|
||||
# return state
|
||||
#
|
||||
# def _prepare_parameters(self, state: WorkflowState) -> dict:
|
||||
# """准备工具参数(支持变量替换)"""
|
||||
# parameters = {}
|
||||
#
|
||||
# for key, value in self.tool_parameters.items():
|
||||
# if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# # 变量替换
|
||||
# var_path = value[2:-1]
|
||||
#
|
||||
# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result}
|
||||
# if "." in var_path:
|
||||
# parts = var_path.split(".")
|
||||
# current = state.get("variables", {})
|
||||
#
|
||||
# for part in parts:
|
||||
# if isinstance(current, dict) and part in current:
|
||||
# current = current[part]
|
||||
# else:
|
||||
# # 尝试从运行时变量获取
|
||||
# runtime_key = ".".join(parts)
|
||||
# current = state.get("runtime_vars", {}).get(runtime_key, value)
|
||||
# break
|
||||
#
|
||||
# parameters[key] = current
|
||||
# else:
|
||||
# # 简单变量
|
||||
# variables = state.get("variables", {})
|
||||
# parameters[key] = variables.get(var_path, value)
|
||||
# else:
|
||||
# parameters[key] = value
|
||||
#
|
||||
# return parameters
|
||||
#
|
||||
#
|
||||
# # 注册工具节点到NodeFactory(如果存在)
|
||||
# try:
|
||||
# from app.core.workflow.nodes import NodeFactory
|
||||
# if hasattr(NodeFactory, 'register_node_type'):
|
||||
# NodeFactory.register_node_type("tool", ToolWorkflowNode)
|
||||
# logger.info("工具节点已注册到工作流系统")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"注册工具节点失败: {e}")
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.graph import START, END
|
||||
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||
from langgraph.types import Send
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
@@ -15,6 +18,149 @@ from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OutputContent(BaseModel):
|
||||
"""
|
||||
Represents a single output segment of an End node.
|
||||
|
||||
An output segment can be either:
|
||||
- literal text (static string)
|
||||
- a variable placeholder (e.g. {{ node.field }})
|
||||
|
||||
Each segment has its own activation state, which is especially
|
||||
important in stream mode.
|
||||
"""
|
||||
|
||||
literal: str = Field(
|
||||
...,
|
||||
description="Raw output content. Can be literal text or a variable placeholder."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
"""
|
||||
Check if this segment depends on a given scope.
|
||||
|
||||
Args:
|
||||
scope (str): Node ID or special variable prefix (e.g., "sys").
|
||||
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
|
||||
return bool(re.search(pattern, self.literal))
|
||||
|
||||
|
||||
class StreamOutputConfig(BaseModel):
|
||||
"""
|
||||
Streaming output configuration for an End node.
|
||||
|
||||
This configuration describes how the End node output behaves in streaming mode,
|
||||
including:
|
||||
- whether output emission is globally activated
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
)
|
||||
|
||||
control_nodes: dict[str, str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
)
|
||||
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
"""
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status == self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
and self.outputs[i].depends_on_scope(scope)
|
||||
):
|
||||
self.outputs[i].activate = True
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -29,10 +175,16 @@ class GraphBuilder:
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_node_ids = []
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.add_edges()
|
||||
self._analyze_end_node_output()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
@property
|
||||
@@ -43,79 +195,207 @@ class GraphBuilder:
|
||||
def edges(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("edges", [])
|
||||
|
||||
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||
"""
|
||||
Analyze the prefix configuration for End nodes.
|
||||
def get_node_type(self, node_id: str) -> str:
|
||||
"""Retrieve the type of node given its ID.
|
||||
|
||||
This function scans each End node's output template, identifies
|
||||
references to its direct upstream nodes, and extracts the prefix
|
||||
string appearing before the first reference.
|
||||
Args:
|
||||
node_id (str): The unique identifier of the node.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
- dict[str, str]: Mapping from upstream node ID to its End node prefix
|
||||
- set[str]: Set of node IDs that are directly adjacent to End nodes and referenced
|
||||
str: The type of the node.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no node with the given `node_id` exists.
|
||||
"""
|
||||
import re
|
||||
try:
|
||||
return self.node_map[node_id]["type"]
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Node not found: Id={node_id}")
|
||||
|
||||
prefixes = {}
|
||||
adjacent_and_referenced = set() # Record nodes directly adjacent to End and referenced
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||
"""
|
||||
Recursively find all upstream branch (control) nodes that influence the execution
|
||||
of the given target node.
|
||||
|
||||
# 找到所有 End 节点
|
||||
This method walks upstream along the workflow graph starting from `target_node`.
|
||||
It distinguishes between:
|
||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||
- non-branch nodes (ordinary processing nodes)
|
||||
|
||||
Traversal rules:
|
||||
1. For each immediate upstream node:
|
||||
- If it is a branch node, it is recorded as an affecting control node.
|
||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||
a branch node, the traversal is considered invalid:
|
||||
- `has_branch` will be False
|
||||
- no branch nodes are returned.
|
||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||
branch node will `has_branch` be True.
|
||||
|
||||
Special case:
|
||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||
it is considered directly reachable from the workflow entry, and therefore
|
||||
has no controlling branch nodes.
|
||||
|
||||
Args:
|
||||
target_node (str):
|
||||
The identifier of the node whose upstream control branches
|
||||
are to be resolved.
|
||||
|
||||
Returns:
|
||||
tuple[bool, tuple[tuple[str, str]]]:
|
||||
- has_branch (bool):
|
||||
True if every upstream path from `target_node` encounters
|
||||
at least one branch node.
|
||||
False if any path reaches a start node without a branch.
|
||||
- branch_nodes (tuple[tuple[str, str]]):
|
||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||
representing all branch nodes that can influence `target_node`.
|
||||
Returns an empty tuple if `has_branch` is False.
|
||||
"""
|
||||
source_nodes = [
|
||||
{
|
||||
"id": edge.get("source"),
|
||||
"branch": edge.get("label")
|
||||
}
|
||||
for edge in self.edges
|
||||
if edge.get("target") == target_node
|
||||
]
|
||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||
return False, tuple()
|
||||
|
||||
branch_nodes = []
|
||||
non_branch_nodes = []
|
||||
|
||||
for node_info in source_nodes:
|
||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||
branch_nodes.append(
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
else:
|
||||
non_branch_nodes.append(node_info["id"])
|
||||
|
||||
has_branch = True
|
||||
for node_id in non_branch_nodes:
|
||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
||||
has_branch = has_branch and node_has_branch
|
||||
if not has_branch:
|
||||
break
|
||||
branch_nodes.extend(nodes)
|
||||
if not has_branch:
|
||||
branch_nodes = []
|
||||
|
||||
return has_branch, tuple(set(branch_nodes))
|
||||
|
||||
def _analyze_end_node_output(self):
|
||||
"""
|
||||
Analyze output templates of all End nodes and generate StreamOutputConfig.
|
||||
|
||||
This method is responsible for parsing the `output` field of End nodes,
|
||||
splitting literal text and variable placeholders (e.g. {{ node.field }}),
|
||||
and determining whether each output segment should be activated immediately
|
||||
or controlled by upstream branch nodes.
|
||||
|
||||
In stream mode:
|
||||
- If the End node is controlled by any upstream branch node, the output
|
||||
will be initially inactive and controlled by those branch nodes.
|
||||
- Otherwise, the output is activated immediately.
|
||||
|
||||
In non-stream mode:
|
||||
- All outputs are activated by default.
|
||||
"""
|
||||
|
||||
# Collect all End nodes in the workflow
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||
|
||||
# Iterate through each End node to analyze its output
|
||||
for end_node in end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
output_template = end_node.get("config", {}).get("output")
|
||||
config = end_node.get("config", {})
|
||||
output = config.get("output")
|
||||
|
||||
logger.info(f"[Prefix Analysis] End node {end_node_id} template: {output_template}")
|
||||
|
||||
if not output_template:
|
||||
# Skip End nodes without output configuration
|
||||
if not output:
|
||||
continue
|
||||
|
||||
# Find all node references in the template
|
||||
# Matches {{node_id.xxx}} or {{ node_id.xxx }} format (allowing spaces)
|
||||
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
matches = list(re.finditer(pattern, output_template))
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||
|
||||
logger.info(f"[Prefix Analysis] 模板中找到 {len(matches)} 个节点引用")
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
|
||||
# Identify all direct upstream nodes connected to the End node
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.edges:
|
||||
if edge.get("target") == end_node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
# Split output into ordered segments
|
||||
output_template = list(re.findall(pattern, output))
|
||||
|
||||
logger.info(f"[Prefix Analysis] Direct upstream nodes of End node: {direct_upstream_nodes}")
|
||||
# Determine whether each segment is literal text
|
||||
# True -> literal (can be directly output)
|
||||
# False -> variable placeholder (needs runtime value)
|
||||
output_flag = [
|
||||
not bool(variable_pattern.match(item))
|
||||
for item in output_template
|
||||
]
|
||||
|
||||
# 找到第一个直接上游节点的引用
|
||||
for match in matches:
|
||||
referenced_node_id = match.group(1)
|
||||
logger.info(f"[Prefix Analysis] Checking reference: {referenced_node_id}")
|
||||
# Stream mode: output activation depends on upstream branch nodes
|
||||
if self.stream:
|
||||
# Find upstream branch nodes that can control this End node
|
||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
||||
|
||||
if referenced_node_id in direct_upstream_nodes:
|
||||
# 这是直接上游节点的引用,提取前缀
|
||||
prefix = output_template[:match.start()]
|
||||
# Build StreamOutputConfig for this End node
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
# If there is no upstream branch, output is active immediately
|
||||
activate=not has_branch,
|
||||
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Found reference to direct upstream node {referenced_node_id}, prefix: '{prefix}'")
|
||||
# Branch nodes that control activation of this End node
|
||||
control_nodes=dict(control_nodes),
|
||||
|
||||
# 标记这个节点为"相邻且被引用"
|
||||
adjacent_and_referenced.add(referenced_node_id)
|
||||
# Convert output segments into OutputContent objects
|
||||
outputs=list(
|
||||
[
|
||||
OutputContent(
|
||||
literal=output_string,
|
||||
# Literal text can be activated immediately unless blocked by branch
|
||||
activate=activate,
|
||||
# Variable segments are marked explicitly
|
||||
is_variable=not activate
|
||||
)
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
# Cursor for streaming output (initially 0)
|
||||
cursor=0
|
||||
)
|
||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||
f"activate: {not has_branch}, "
|
||||
f"control_nodes: {control_nodes},"
|
||||
f"output: {output_template},"
|
||||
f"output_activate: {output_flag}")
|
||||
|
||||
if prefix:
|
||||
prefixes[referenced_node_id] = prefix
|
||||
logger.info(f"[Prefix Analysis] "
|
||||
f"✅ Assign prefix for node {referenced_node_id}: '{prefix[:50]}...'")
|
||||
|
||||
# 只处理第一个直接上游节点的引用
|
||||
break
|
||||
|
||||
logger.info(f"[Prefix Analysis] Final prefixes: {prefixes}")
|
||||
logger.info(f"[Prefix Analysis] Nodes adjacent to End and referenced: {adjacent_and_referenced}")
|
||||
return prefixes, adjacent_and_referenced
|
||||
# Non-stream mode: all outputs are activated by default
|
||||
else:
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
activate=True,
|
||||
control_nodes={},
|
||||
outputs=list(
|
||||
[
|
||||
OutputContent(
|
||||
literal=output_string,
|
||||
activate=True,
|
||||
is_variable=not activate
|
||||
)
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
cursor=0
|
||||
)
|
||||
|
||||
def add_nodes(self):
|
||||
"""Add all nodes from the workflow configuration to the state graph.
|
||||
@@ -135,9 +415,6 @@ class GraphBuilder:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Analyze End node prefixes if in stream mode
|
||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
||||
|
||||
for node in self.nodes:
|
||||
node_type = node.get("type")
|
||||
node_id = node.get("id")
|
||||
@@ -171,17 +448,6 @@ class GraphBuilder:
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# Inject End node prefix configuration if in stream mode
|
||||
if self.stream and node_id in end_prefixes:
|
||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||
logger.info(f"Injected End prefix for node {node_id}")
|
||||
|
||||
# Mark nodes as adjacent and referenced to End node in stream mode
|
||||
if self.stream:
|
||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||
if node_id in adjacent_and_referenced:
|
||||
logger.info(f"Node {node_id} marked as adjacent and referenced to End node")
|
||||
|
||||
# Wrap node's run method to avoid closure issues
|
||||
if self.stream:
|
||||
# Stream mode: create an async generator function
|
||||
@@ -261,6 +527,7 @@ class GraphBuilder:
|
||||
for source_node, branches in conditional_edges.items():
|
||||
def make_router(src, branch_list):
|
||||
"""reate a router function for each source node that routes to a NOP node for later merging."""
|
||||
|
||||
def make_branch_node(node_name, targets):
|
||||
def node(s):
|
||||
# NOTE: NOP NODE MUST NOT MODIFY STATE
|
||||
|
||||
@@ -67,10 +67,6 @@ class WorkflowState(TypedDict):
|
||||
error: str | None
|
||||
error_node: str | None
|
||||
|
||||
# Streaming buffer (stores real-time streaming output of nodes)
|
||||
# Format: {node_id: {"chunks": [...], "full_content": "..."}}
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
@@ -300,7 +296,7 @@ class BaseNode(ABC):
|
||||
"""
|
||||
if not self.check_activate(state):
|
||||
yield self.trans_activate(state)
|
||||
logger.info(f"跳过节点{self.node_id}")
|
||||
logger.info(f"jump node: {self.node_id}")
|
||||
return
|
||||
|
||||
import time
|
||||
@@ -313,19 +309,6 @@ class BaseNode(ABC):
|
||||
# Get LangGraph's stream writer for sending custom data
|
||||
writer = get_stream_writer()
|
||||
|
||||
# Check if this is an End node
|
||||
# End nodes CAN send chunks (for suffix), but only after LLM content
|
||||
is_end_node = self.node_type == "end"
|
||||
|
||||
# Check if this node is adjacent to End node (for message type)
|
||||
is_adjacent_to_end = getattr(self, '_is_adjacent_to_end', False)
|
||||
|
||||
# Determine chunk type: "message" for End and adjacent nodes, "node_chunk" for others
|
||||
chunk_type = "message" if (is_end_node or is_adjacent_to_end) else "node_chunk"
|
||||
|
||||
logger.debug(
|
||||
f"节点 {self.node_id} chunk 类型: {chunk_type} (is_end={is_end_node}, adjacent={is_adjacent_to_end})")
|
||||
|
||||
# Accumulate complete result (for final wrapping)
|
||||
chunks = []
|
||||
final_result = None
|
||||
@@ -340,66 +323,25 @@ class BaseNode(ABC):
|
||||
raise TimeoutError()
|
||||
|
||||
# Check if it's a completion marker
|
||||
if isinstance(item, dict) and item.get("__final__"):
|
||||
if item.get("__final__"):
|
||||
final_result = item["result"]
|
||||
elif isinstance(item, str):
|
||||
# String is a chunk
|
||||
else:
|
||||
chunk_count += 1
|
||||
chunks.append(item)
|
||||
full_content = "".join(chunks)
|
||||
content = str(item.get("chunk"))
|
||||
done = item.get("done", False)
|
||||
chunks.append(content)
|
||||
|
||||
# Send chunks for all nodes (including End nodes for suffix)
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {item[:50]}...")
|
||||
logger.debug(f"节点 {self.node_id} 发送 chunk #{chunk_count}: {content[:50]}...")
|
||||
|
||||
# 1. Send via stream writer (for real-time client updates)
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"type": "node_chunk",
|
||||
"node_id": self.node_id,
|
||||
"chunk": item,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
"chunk": content,
|
||||
"done": done
|
||||
})
|
||||
|
||||
# 2. Update streaming buffer in state (for downstream nodes)
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
else:
|
||||
# Other types are also treated as chunks
|
||||
chunk_count += 1
|
||||
chunk_str = str(item)
|
||||
chunks.append(chunk_str)
|
||||
full_content = "".join(chunks)
|
||||
|
||||
# Send chunks for all nodes
|
||||
writer({
|
||||
"type": chunk_type, # "message" or "node_chunk"
|
||||
"node_id": self.node_id,
|
||||
"chunk": chunk_str,
|
||||
"full_content": full_content,
|
||||
"chunk_index": chunk_count
|
||||
})
|
||||
|
||||
# Only non-End nodes need streaming buffer
|
||||
if not is_end_node:
|
||||
yield {
|
||||
"streaming_buffer": {
|
||||
self.node_id: {
|
||||
"full_content": full_content,
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
logger.info(f"节点 {self.node_id} 流式执行完成,耗时: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
@@ -426,16 +368,6 @@ class BaseNode(ABC):
|
||||
"looping": state["looping"]
|
||||
}
|
||||
|
||||
# Add streaming buffer for non-End nodes
|
||||
if not is_end_node:
|
||||
state_update["streaming_buffer"] = {
|
||||
self.node_id: {
|
||||
"full_content": "".join(chunks),
|
||||
"chunk_count": chunk_count,
|
||||
"is_complete": True # Mark as complete
|
||||
}
|
||||
}
|
||||
|
||||
# Finally yield state update
|
||||
# LangGraph will merge this into state
|
||||
yield state_update | self.trans_activate(state)
|
||||
@@ -544,6 +476,11 @@ class BaseNode(ABC):
|
||||
"error_node": self.node_id
|
||||
}
|
||||
else:
|
||||
writer = get_stream_writer()
|
||||
writer({
|
||||
"type": "node_error",
|
||||
**node_output
|
||||
})
|
||||
# 无错误边:抛出异常停止工作流
|
||||
logger.error(f"节点 {self.node_id} 执行失败,停止工作流: {error_message}")
|
||||
raise Exception(f"节点 {self.node_id} 执行失败: {error_message}")
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.core.workflow.nodes.code.node import CodeNode
|
||||
|
||||
__all__ = ["CodeNode"]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user