Merge branch 'develop' into feature/knowledgeBase_yjp
This commit is contained in:
25
api/app/base/type.py
Normal file
25
api/app/base/type.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import TypeDecorator, JSON
|
||||
|
||||
|
||||
class PydanticType(TypeDecorator):
    """SQLAlchemy column type that round-trips a Pydantic model through JSON.

    On write, instances of the configured model are converted to plain data;
    on read, the stored data is validated back into the configured model.
    """

    impl = JSON

    def __init__(self, pydantic_model: type[BaseModel]):
        """Remember the model class used for this column.

        Args:
            pydantic_model: The ``BaseModel`` subclass that stored values are
                validated into when rows are loaded.
        """
        super().__init__()
        self.model = pydantic_model

    def process_bind_param(self, value, dialect):
        """Convert a Python value into data for the JSON column (write path).

        Args:
            value: ``None``, an instance of the configured model, or any other
                value (for example an already-built dict).
            dialect: The dialect in use; not consulted here.

        Returns:
            ``None`` for ``None``; JSON-compatible data for a model instance;
            any other value unchanged.
        """
        if value is None:
            return None
        if isinstance(value, self.model):
            # Use the Pydantic v2 API (``.dict()`` is the deprecated v1 name).
            # ``mode="json"`` converts nested datetime/UUID/Enum/Decimal values
            # into JSON-serializable primitives, which a JSON column requires.
            return value.model_dump(mode="json")
        # Anything else (e.g. a plain dict) is passed through untouched.
        return value

    def process_result_value(self, value, dialect):
        """Convert stored JSON data back into the configured model (read path).

        Args:
            value: The data loaded from the column, or ``None``.
            dialect: The dialect in use; not consulted here.

        Returns:
            ``None`` for ``None``; otherwise the result of validating the
            value with the configured model (Pydantic v2 ``model_validate``).
        """
        if value is None:
            return None
        return self.model.model_validate(value)
|
||||
@@ -4,39 +4,45 @@
|
||||
认证方式: JWT Token
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import (
|
||||
model_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
user_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
file_controller,
|
||||
home_page_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_forget_controller,
|
||||
memory_reflection_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
public_share_controller,
|
||||
release_share_controller,
|
||||
setup_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
tool_controller,
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
@@ -71,8 +77,12 @@ manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(memory_short_term_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(memory_forget_controller.router)
|
||||
manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -48,6 +48,7 @@ def list_apps(
|
||||
include_shared: bool = True,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -55,8 +56,19 @@ def list_apps(
|
||||
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
if ids is not None:
|
||||
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
db,
|
||||
workspace_id=workspace_id,
|
||||
@@ -69,8 +81,6 @@ def list_apps(
|
||||
pagesize=pagesize,
|
||||
)
|
||||
|
||||
# 使用 AppService 的转换方法来设置 is_shared 字段
|
||||
service = app_service.AppService(db)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
@@ -506,7 +516,7 @@ async def draft_run(
|
||||
multi_agent_request = MultiAgentRunRequest(
|
||||
message=payload.message,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables or {},
|
||||
use_llm_routing=True # 默认启用 LLM 路由
|
||||
)
|
||||
@@ -728,9 +738,23 @@ async def draft_run_compare(
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
|
||||
|
||||
# 获取 agent_cfg.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
agent_model_params = agent_cfg.model_parameters
|
||||
if hasattr(agent_model_params, 'model_dump'):
|
||||
agent_model_params = agent_model_params.model_dump()
|
||||
elif not isinstance(agent_model_params, dict):
|
||||
agent_model_params = {}
|
||||
|
||||
# 获取 model_item.model_parameters,如果是 ModelParameters 对象则转为字典
|
||||
item_model_params = model_item.model_parameters
|
||||
if hasattr(item_model_params, 'model_dump'):
|
||||
item_model_params = item_model_params.model_dump()
|
||||
elif not isinstance(item_model_params, dict):
|
||||
item_model_params = {}
|
||||
|
||||
merged_parameters = {
|
||||
**(agent_cfg.model_parameters or {}),
|
||||
**(model_item.model_parameters or {})
|
||||
**(agent_model_params or {}),
|
||||
**(item_model_params or {})
|
||||
}
|
||||
|
||||
model_configs.append({
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -26,4 +27,12 @@ def get_workspace_list(
|
||||
):
|
||||
"""获取工作空间列表"""
|
||||
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
return success(data=workspace_list, msg="工作空间列表获取成功")
|
||||
|
||||
@router.get("/version", response_model=ApiResponse)
def get_system_version():
    """Return the system version number together with its introduction text."""
    version_info = {
        "version": settings.SYSTEM_VERSION,
        "introduction": settings.SYSTEM_INTRODUCTION,
    }
    return success(data=version_info, msg="系统版本获取成功")
|
||||
312
api/app/controllers/implicit_memory_controller.py
Normal file
312
api/app/controllers/implicit_memory_controller.py
Normal file
@@ -0,0 +1,312 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import (
|
||||
cur_workspace_access_guard,
|
||||
get_current_user,
|
||||
)
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/implicit-memory",
|
||||
tags=["Implicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
    """
    Translate an exception from an implicit-memory operation into a failure response.

    The exception type is checked first (ValueError, KeyError, connection or
    timeout errors); otherwise keywords in the exception message select the
    business code. Every branch logs through ``api_logger`` and returns the
    value produced by ``fail``.

    Args:
        e: The exception that occurred
        operation: Description of the operation that failed
        user_id: Optional user ID used only for the log context

    Returns:
        Standardized error response built by ``fail``
    """
    who = f"user_id={user_id}" if user_id else "unknown user"
    detail = str(e)
    lowered = detail.lower()

    if isinstance(e, ValueError):
        # ValueErrors are sub-classified by keywords in their message.
        if "user" in lowered and "not found" in lowered:
            api_logger.warning(f"Invalid user ID for {operation}: {who}")
            return fail(BizCode.INVALID_USER_ID, "无效的用户ID", detail)
        if "insufficient" in lowered or "no data" in lowered:
            api_logger.warning(f"Insufficient data for {operation}: {who}")
            return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", detail)
        api_logger.warning(f"Invalid parameters for {operation}: {who}")
        return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", detail)

    if isinstance(e, KeyError):
        api_logger.warning(f"Missing required data for {operation}: {who}")
        return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", detail)

    if isinstance(e, (ConnectionError, TimeoutError)):
        api_logger.error(f"Service unavailable for {operation}: {who}")
        return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", detail)

    # Remaining exception types are classified purely by message keywords.
    if "analysis" in lowered or "llm" in lowered:
        api_logger.error(f"Analysis failed for {operation}: {who}", exc_info=True)
        return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", detail)

    if "storage" in lowered or "database" in lowered:
        api_logger.error(f"Storage error for {operation}: {who}", exc_info=True)
        return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", detail)

    api_logger.error(f"Unexpected error for {operation}: {who}", exc_info=True)
    return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", detail)
|
||||
|
||||
|
||||
def validate_user_id(user_id: str) -> None:
    """
    Validate that a user ID is present and not blank.

    Args:
        user_id: User ID to validate

    Raises:
        ValueError: If the user ID is empty, ``None`` or whitespace only
    """
    # A single guard is enough: any value whose stripped length is below one
    # is already rejected here, so a separate "too short" check can never fire.
    if not user_id or not user_id.strip():
        raise ValueError("User ID cannot be empty")
|
||||
|
||||
|
||||
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
    """
    Validate date range parameters.

    Either both bounds are supplied or neither is. When supplied, both must
    share the same timezone awareness, the start must be strictly before the
    end, and the start must not lie in the future.

    Args:
        start_date: Start date
        end_date: End date

    Raises:
        ValueError: If only one bound is given, the bounds mix timezone-aware
            and naive values, the start is not before the end, or the start
            is in the future
    """
    if start_date is None and end_date is None:
        return
    if start_date is None or end_date is None:
        raise ValueError("Both start_date and end_date must be provided together")

    # Comparing an aware datetime with a naive one raises TypeError, so reject
    # mixed input explicitly with the documented exception type.
    start_is_aware = start_date.utcoffset() is not None
    end_is_aware = end_date.utcoffset() is not None
    if start_is_aware != end_is_aware:
        raise ValueError("start_date and end_date must both be timezone-aware or both be naive")

    if start_date >= end_date:
        raise ValueError("start_date must be before end_date")

    # Build "now" with the same awareness as the input so the comparison is valid.
    now = datetime.now(tz=start_date.tzinfo) if start_is_aware else datetime.now()
    if start_date > now:
        raise ValueError("start_date cannot be in the future")
|
||||
|
||||
|
||||
def validate_confidence_threshold(threshold: float) -> None:
    """
    Check that a confidence threshold lies within the closed interval [0.0, 1.0].

    Args:
        threshold: Confidence threshold to validate

    Raises:
        ValueError: If threshold falls outside the interval
    """
    if 0.0 <= threshold <= 1.0:
        return
    raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
    user_id: str,
    confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
    tag_category: Optional[str] = Query(None, description="Filter by tag category"),
    start_date: Optional[datetime] = Query(None, description="Filter start date"),
    end_date: Optional[datetime] = Query(None, description="Filter end date"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Return a user's preference tags, optionally filtered.

    Inputs are validated first; any exception raised along the way is turned
    into a failure response by ``handle_implicit_memory_error``.

    Args:
        user_id: Target user ID
        confidence_threshold: Minimum confidence score (0.0-1.0)
        tag_category: Optional category filter
        start_date: Optional start date filter
        end_date: Optional end date filter

    Returns:
        Success response whose data is the list of serialized preference tags
    """
    api_logger.info(f"Preference tags requested for user: {user_id}")

    try:
        # Reject bad input before touching the service layer.
        validate_user_id(user_id)
        validate_confidence_threshold(confidence_threshold)
        validate_date_range(start_date, end_date)

        memory_service = ImplicitMemoryService(db=db, end_user_id=user_id)

        # A date range filter is only built when both bounds are present.
        date_range = None
        if start_date and end_date:
            from app.schemas.implicit_memory_schema import DateRange

            date_range = DateRange(start_date=start_date, end_date=end_date)

        filters = {
            "user_id": user_id,
            "confidence_threshold": confidence_threshold,
            "tag_category": tag_category,
            "date_range": date_range,
        }
        preference_tags = await memory_service.get_preference_tags(**filters)

        api_logger.info(f"Retrieved {len(preference_tags)} preference tags for user: {user_id}")
        payload = [item.model_dump(mode='json') for item in preference_tags]
        return success(data=payload, msg="偏好标签获取成功")

    except Exception as e:
        return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
    user_id: str,
    include_history: bool = Query(False, description="Include historical trends"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Return the user's four-dimension personality portrait.

    Any exception raised while validating or fetching is turned into a failure
    response by ``handle_implicit_memory_error``.

    Args:
        user_id: Target user ID
        include_history: Whether to include historical trend data

    Returns:
        Success response whose data is the serialized portrait
    """
    api_logger.info(f"Dimension portrait requested for user: {user_id}")

    try:
        validate_user_id(user_id)

        memory_service = ImplicitMemoryService(db=db, end_user_id=user_id)
        dimension_portrait = await memory_service.get_dimension_portrait(
            user_id=user_id,
            include_history=include_history,
        )

        api_logger.info(f"Dimension portrait retrieved for user: {user_id}")
        payload = dimension_portrait.model_dump(mode='json')
        return success(data=payload, msg="四维画像获取成功")

    except Exception as e:
        return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
    user_id: str,
    include_trends: bool = Query(False, description="Include trend analysis"),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Return the user's interest area distribution.

    Any exception raised while validating or fetching is turned into a failure
    response by ``handle_implicit_memory_error``.

    Args:
        user_id: Target user ID
        include_trends: Whether to include trend analysis data

    Returns:
        Success response whose data is the serialized distribution
    """
    api_logger.info(f"Interest area distribution requested for user: {user_id}")

    try:
        validate_user_id(user_id)

        memory_service = ImplicitMemoryService(db=db, end_user_id=user_id)
        area_distribution = await memory_service.get_interest_area_distribution(
            user_id=user_id,
            include_trends=include_trends,
        )

        api_logger.info(f"Interest area distribution retrieved for user: {user_id}")
        payload = area_distribution.model_dump(mode='json')
        return success(data=payload, msg="兴趣领域分布获取成功")

    except Exception as e:
        return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
    user_id: str,
    # ``pattern`` replaces the deprecated ``regex`` keyword of ``Query``.
    confidence_level: Optional[str] = Query(
        None,
        pattern="^(high|medium|low)$",
        description="Filter by confidence level",
    ),
    frequency_pattern: Optional[str] = Query(
        None,
        pattern="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$",
        description="Filter by frequency pattern",
    ),
    time_period: Optional[str] = Query(
        None,
        pattern="^(current|past)$",
        description="Filter by time period",
    ),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
) -> ApiResponse:
    """
    Get user's behavioral habits with filtering options.

    Any exception raised while validating or fetching is turned into a failure
    response by ``handle_implicit_memory_error``.

    Args:
        user_id: Target user ID
        confidence_level: Filter by confidence level (high, medium, low)
        frequency_pattern: Filter by frequency pattern (daily, weekly, monthly,
            seasonal, occasional, event_triggered)
        time_period: Filter by time period (current, past)

    Returns:
        Success response whose data is the list of serialized habits
    """
    api_logger.info(f"Behavior habits requested for user: {user_id}")

    try:
        # Validate inputs
        validate_user_id(user_id)

        # The service takes a numeric confidence value, so map the label to one.
        numerical_confidence = None
        if confidence_level:
            confidence_mapping = {
                "high": 85,
                "medium": 50,
                "low": 20,
            }
            numerical_confidence = confidence_mapping.get(confidence_level.lower())

        service = ImplicitMemoryService(db=db, end_user_id=user_id)

        habits = await service.get_behavior_habits(
            user_id=user_id,
            confidence_level=numerical_confidence,
            frequency_pattern=frequency_pattern,
            time_period=time_period,
        )

        api_logger.info(f"Retrieved {len(habits)} behavior habits for user: {user_id}")
        return success(data=[habit.model_dump(mode='json') for habit in habits], msg="行为习惯获取成功")

    except Exception as e:
        return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
|
||||
|
||||
@@ -76,9 +76,28 @@ async def trigger_forgetting_cycle(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 通过 group_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(payload.group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {payload.group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {payload.group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={payload.group_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
|
||||
f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"group_id={payload.group_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"min_days={payload.min_days_since_access}"
|
||||
)
|
||||
|
||||
@@ -89,7 +108,7 @@ async def trigger_forgetting_cycle(
|
||||
group_id=payload.group_id,
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=payload.config_id
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
@@ -217,7 +236,6 @@ async def update_forgetting_config(
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
@@ -227,8 +245,7 @@ async def get_forgetting_stats(
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(可选)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
group_id: 组ID(即 end_user_id,可选)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
@@ -242,6 +259,27 @@ async def get_forgetting_stats(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 如果提供了 group_id,通过它获取 config_id
|
||||
config_id = None
|
||||
if group_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
|
||||
255
api/app/controllers/memory_perceptual_controller.py
Normal file
255
api/app/controllers/memory_perceptual_controller.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.models.memory_perceptual_model import PerceptualType
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
PerceptualFilter
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/perceptual",
|
||||
tags=["Perceptual Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Retrieve perceptual memory statistics for a user group.

    Args:
        group_id: ID of the user group (usually end_user_id in this context)
        current_user: Current authenticated user
        db: Database session

    Returns:
        ApiResponse: Response containing memory count statistics, or an
        internal-error response if the service call raises
    """
    api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")

    try:
        service = MemoryPerceptualService(db)
        count_stats = service.get_memory_count(group_id)

        api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")

        return success(
            data=count_stats,
            msg="Memory statistics retrieved successfully"
        )

    except Exception as e:
        # exc_info keeps the traceback; the response only carries a generic message.
        api_logger.error(
            f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}",
            exc_info=True,
        )
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg="Failed to fetch memory statistics",
        )
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Retrieve the most recent VISION-type memory for a user.

    Args:
        group_id: ID of the user group
        current_user: Current authenticated user
        db: Database session

    Returns:
        ApiResponse: Metadata of the latest visual memory, a success response
        with ``data=None`` when there is none, or an internal-error response
        if the service call raises
    """
    api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")

    try:
        service = MemoryPerceptualService(db)
        visual_memory = service.get_latest_visual_memory(group_id)

        # "Nothing stored yet" is a normal outcome, not an error.
        if visual_memory is None:
            api_logger.info(f"No visual memory found: group_id={group_id}")
            return success(
                data=None,
                msg="No visual memory available"
            )

        api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")

        return success(
            data=visual_memory,
            msg="Latest visual memory retrieved successfully"
        )

    except Exception as e:
        # exc_info keeps the traceback; the response only carries a generic message.
        api_logger.error(
            f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}",
            exc_info=True,
        )
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg="Failed to fetch latest visual memory",
        )
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Retrieve the most recent AUDIO-type memory for a user.

    Args:
        group_id: ID of the user group
        current_user: Current authenticated user
        db: Database session

    Returns:
        ApiResponse: Metadata of the latest audio memory, a success response
        with ``data=None`` when there is none, or an internal-error response
        if the service call raises
    """
    api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")

    try:
        service = MemoryPerceptualService(db)
        audio_memory = service.get_latest_audio_memory(group_id)

        # "Nothing stored yet" is a normal outcome, not an error.
        if audio_memory is None:
            api_logger.info(f"No audio memory found: group_id={group_id}")
            return success(
                data=None,
                msg="No audio memory available"
            )

        api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")

        return success(
            data=audio_memory,
            msg="Latest audio memory retrieved successfully"
        )

    except Exception as e:
        # exc_info keeps the traceback; the response only carries a generic message.
        api_logger.error(
            f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}",
            exc_info=True,
        )
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg="Failed to fetch latest audio memory",
        )
|
||||
|
||||
|
||||
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Retrieve the most recent TEXT-type memory for a user.

    Args:
        group_id: ID of the user group
        current_user: Current authenticated user
        db: Database session

    Returns:
        ApiResponse: Metadata of the latest text memory, a success response
        with ``data=None`` when there is none, or an internal-error response
        if the service call raises
    """
    api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")

    try:
        # Ask the service layer for the most recent text memory.
        service = MemoryPerceptualService(db)
        text_memory = service.get_latest_text_memory(group_id)

        # "Nothing stored yet" is a normal outcome, not an error.
        if text_memory is None:
            api_logger.info(f"No text memory found: group_id={group_id}")
            return success(
                data=None,
                msg="No text memory available"
            )

        api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")

        return success(
            data=text_memory,
            msg="Latest text memory retrieved successfully"
        )

    except Exception as e:
        # exc_info keeps the traceback; the response only carries a generic message.
        api_logger.error(
            f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}",
            exc_info=True,
        )
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg="Failed to fetch latest text memory",
        )
|
||||
|
||||
|
||||
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
    group_id: uuid.UUID,
    perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
    page: int = Query(1, ge=1, description="页码"),
    page_size: int = Query(10, ge=1, le=100, description="每页大小"),
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Retrieve a timeline of perceptual memories for a user group.

    Args:
        group_id: ID of the user group
        perceptual_type: Optional filter for perceptual type
        page: Page number for pagination (1-based)
        page_size: Number of items per page (1-100)
        current_user: Current authenticated user
        db: Database session

    Returns:
        ApiResponse: Timeline data of perceptual memories, or an
        internal-error response if building the query or the service call
        raises
    """
    api_logger.info(
        f"Fetching perceptual memory timeline: user={current_user.username}, "
        f"group_id={group_id}, type={perceptual_type}, page={page}"
    )

    try:
        query = PerceptualQuerySchema(
            filter=PerceptualFilter(type=perceptual_type),
            page=page,
            page_size=page_size
        )

        service = MemoryPerceptualService(db)
        timeline_data = service.get_time_line(group_id, query)

        api_logger.info(
            f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
            f"returned={len(timeline_data.memories)}"
        )

        return success(
            data=timeline_data.model_dump(),
            msg="Perceptual memory timeline retrieved successfully"
        )

    except Exception as e:
        # exc_info keeps the traceback; the response only carries a generic message.
        api_logger.error(
            f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
            f"error={str(e)}",
            exc_info=True,
        )
        return fail(
            code=BizCode.INTERNAL_ERROR,
            msg="Failed to fetch perceptual memory timeline",
        )
|
||||
43
api/app/controllers/memory_short_term_controller.py
Normal file
43
api/app/controllers/memory_short_term_controller.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
from app.services.memory_storage_service import search_entity
|
||||
from app.services.memory_short_service import ShortService,LongService
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
load_dotenv()
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
)
|
||||
@router.get("/short_term")
async def short_term_configs(
    end_user_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Collect short-term, long-term and entity memory figures for an end user.

    Args:
        end_user_id: Identifier handed to ``ShortService``, ``LongService``
            and ``search_entity``.
        current_user: Authenticated user (dependency only; not read here).
        db: Database session (dependency only; not read here).

    Returns:
        A success response whose data holds the short-term records, the
        long-term records, the entity count (``num`` of the entity search
        result, 0 when absent), the short-term retrieval count and the number
        of long-term records.
    """
    # Fetch short-term memory data.
    short_service = ShortService(end_user_id)
    short_records = short_service.get_short_databasets()
    retrieval_count = short_service.get_short_count()

    long_service = LongService(end_user_id)
    long_records = long_service.get_long_databasets()

    entity_info = await search_entity(end_user_id)
    entity_count = entity_info.get('num', 0)

    payload = {
        'short_term': short_records,
        'long_term': long_records,
        'entity': entity_count,
        "retrieval_number": retrieval_count,
        "long_term_number": len(long_records),
    }
    return success(data=payload, msg="短期记忆系统数据获取成功")
|
||||
134
api/app/controllers/memory_working_controller.py
Normal file
134
api/app/controllers/memory_working_controller.py
Normal file
@@ -0,0 +1,134 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/work",
|
||||
tags=["Working Memory System"],
|
||||
dependencies=[Depends(get_current_user)]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Placeholder for the working-memory count endpoint (not implemented).

    The handler performs no work and returns ``None``.

    NOTE(review): the route declares ``response_model=ApiResponse``; returning
    ``None`` presumably fails response validation at runtime - confirm and
    implement before exposing this endpoint.

    Args:
        group_id: The group identifier taken from the path.
        current_user: Authenticated user (unused).
        db: Database session (unused).
    """
    return None
|
||||
|
||||
|
||||
@router.get("/{group_id}/conversations", response_model=ApiResponse)
def get_conversations(
    group_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """List the conversations that belong to a group.

    Args:
        group_id (UUID): The group identifier; it is the only value passed to
            ``ConversationService.get_user_conversations``.
        current_user (User): The authenticated user (dependency only; not
            read in the body).
        db (Session): SQLAlchemy session used to build the service.

    Returns:
        ApiResponse: A success response whose data is a list of
        ``{"id", "title"}`` dictionaries, one per conversation.
    """
    service = ConversationService(db)
    found = service.get_user_conversations(group_id)

    # Expose only the lightweight fields of each conversation.
    summaries = []
    for item in found:
        summaries.append({"id": item.id, "title": item.title})

    return success(data=summaries, msg="get conversations success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/messages", response_model=ApiResponse)
def get_messages(
    conversation_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the message history of one conversation.

    NOTE(review): the route template contains ``{group_id}`` but the handler
    does not declare it, so ``conversation_id`` is presumably read from the
    query string - confirm this is intended.

    Args:
        conversation_id (UUID): The conversation whose messages are fetched.
        current_user (User): The authenticated user (dependency only; not
            read in the body).
        db (Session): SQLAlchemy session used to build the service.

    Returns:
        ApiResponse: A success response whose data is a list of dictionaries
        with ``role``, ``content`` and ``created_at`` (the message timestamp
        converted to integer milliseconds).
    """
    service = ConversationService(db)
    history = service.get_messages(conversation_id)

    serialized = []
    for entry in history:
        # timestamp() yields seconds; the response carries whole milliseconds.
        created_ms = int(entry.created_at.timestamp() * 1000)
        serialized.append(
            {
                "role": entry.role,
                "content": entry.content,
                "created_at": created_ms,
            }
        )

    return success(data=serialized, msg="get conversation history success")
|
||||
|
||||
|
||||
@router.get("/{group_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
    conversation_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Return the detail record of one conversation.

    Delegates to the asynchronous ``ConversationService.get_conversation_detail``
    with the current user and that user's current workspace, then serializes
    the result with ``model_dump()``.

    NOTE(review): the route template contains ``{group_id}`` but the handler
    does not declare it, so ``conversation_id`` is presumably read from the
    query string - confirm this is intended.

    Args:
        conversation_id (UUID): The ID of the conversation.
        current_user (User): The authenticated user making the request; its
            ``current_workspace_id`` is forwarded as the workspace.
        db (Session): SQLAlchemy session used to build the service.

    Returns:
        ApiResponse: A success response carrying the dumped conversation detail.
    """
    service = ConversationService(db)
    workspace = current_user.current_workspace_id
    conversation_detail = await service.get_conversation_detail(
        user=current_user,
        conversation_id=conversation_id,
        workspace_id=workspace,
    )
    serialized = conversation_detail.model_dump()
    return success(data=serialized, msg="get conversation detail success")
|
||||
@@ -108,16 +108,23 @@ async def get_prompt_opt(
|
||||
service = PromptOptimizerService(db)
|
||||
|
||||
async def event_generator():
|
||||
async for chunk in service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
|
||||
yield "event:start\ndata: {}\n\n"
|
||||
try:
|
||||
async for chunk in service.optimize_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
model_id=data.model_id,
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
except Exception as e:
|
||||
yield f"event:error\ndata: {json.dumps(
|
||||
{"error": str(e)}
|
||||
)}\n\n"
|
||||
yield "event:end\ndata: {}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
@@ -18,7 +19,7 @@ from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
@@ -288,7 +289,7 @@ async def chat(
|
||||
password = None # Token 认证不需要密码
|
||||
# end_user_id = user_id
|
||||
other_id = user_id
|
||||
|
||||
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
|
||||
from app.models.app_model import AppType
|
||||
@@ -364,6 +365,9 @@ async def chat(
|
||||
config = release.config or {}
|
||||
if not config.get("sub_agents"):
|
||||
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# Multi-Agent 类型:验证多 Agent 配置
|
||||
pass
|
||||
else:
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -392,6 +396,8 @@ async def chat(
|
||||
|
||||
if app_type == AppType.AGENT:
|
||||
# 流式返回
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
@@ -424,10 +430,11 @@ async def chat(
|
||||
user_id= str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=payload.agent_config,
|
||||
config=agent_config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -463,10 +470,12 @@ async def chat(
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
# config = workflow_config_4_app_release(release)
|
||||
config = multi_agent_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -479,8 +488,8 @@ async def chat(
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -551,8 +560,71 @@ async def chat(
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
|
||||
config = workflow_config_4_app_release(release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""App 服务接口 - 基于 API Key 认证"""
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Request, Body
|
||||
@@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
@@ -153,7 +154,8 @@ async def chat(
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -177,7 +179,8 @@ async def chat(
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -194,8 +197,8 @@ async def chat(
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -211,7 +214,6 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.multi_agent_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
@@ -226,22 +228,29 @@ async def chat(
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
# 多 Agent 流式返回
|
||||
config = dict_to_workflow_config(app.current_release.config,app.id)
|
||||
config = workflow_config_4_app_release(app.current_release)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -253,23 +262,34 @@ async def chat(
|
||||
}
|
||||
)
|
||||
|
||||
# 非流式返回
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
user_id=new_end_user.id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
extra={
|
||||
"result_type": str(type(result)),
|
||||
"has_response": "response" in result if isinstance(result, dict) else False
|
||||
}
|
||||
)
|
||||
return success(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
else:
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
from fastapi import APIRouter, Depends, status, Query, HTTPException
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
from app.core.models import RedBearLLM, RedBearRerank
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.models.embedding import RedBearEmbeddings
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
|
||||
from app.models.user_model import User
|
||||
from app.schemas import model_schema
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.core.response_utils import success
|
||||
from app.schemas.response_schema import ApiResponse, PageData
|
||||
from app.services.model_service import ModelConfigService, ModelApiKeyService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.app_schema import AppChatRequest
|
||||
from app.services.model_service import ModelConfigService
|
||||
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.dependencies import get_current_user
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -28,6 +27,8 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
# ==================== 原有测试接口 ====================
|
||||
|
||||
@router.get("/llm/{model_id}", response_model=ApiResponse)
|
||||
def test_llm(
|
||||
model_id: uuid.UUID,
|
||||
@@ -50,7 +51,6 @@ def test_llm(
|
||||
template = """Question: {question}
|
||||
|
||||
Answer: Let's think step by step."""
|
||||
# ChatPromptTemplate
|
||||
prompt = ChatPromptTemplate.from_template(template)
|
||||
chain = prompt | llm
|
||||
answer = chain.invoke({"question": "What is LangChain?"})
|
||||
@@ -80,13 +80,13 @@ def test_embedding(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
embeddings = model.embed_documents(data)
|
||||
print(embeddings)
|
||||
query = "我想找一个适合学习的地方。"
|
||||
@@ -114,13 +114,123 @@ def test_rerank(
|
||||
base_url=apiConfig.api_base
|
||||
))
|
||||
query = "最近哪家咖啡店评价最好?"
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
data = [
|
||||
"最近哪家咖啡店评价最好?",
|
||||
"附近有没有推荐的咖啡厅?",
|
||||
"明天天气预报说会下雨。",
|
||||
"北京是中国的首都。",
|
||||
"我想找一个适合学习的地方。"
|
||||
]
|
||||
scores = model.rerank(query=query, documents=data, top_n=3)
|
||||
print(scores)
|
||||
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
|
||||
|
||||
|
||||
# ==================== Handoffs 测试接口 ====================
|
||||
|
||||
@router.post("/handoffs/{app_id}")
async def test_handoffs(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    request: AppChatRequest = Body(...),
    current_user=Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """Exercise the Agent Handoffs feature for an application.

    Resolves the conversation (validating an existing one or creating a draft
    conversation), then runs the handoffs service either as a streaming
    response or as a single awaited call, depending on ``request.stream``.

    Event types listed for the streaming mode:
    - start: execution started
    - agent: current agent information
    - message: streamed message content
    - handoff: agent switch event
    - end: execution finished
    - error: error information

    Args:
        app_id: Application ID taken from the path.
        request: Chat request; ``conversation_id``, ``user_id``, ``message``
            and ``stream`` are read from it.
        current_user: Authenticated user; supplies the current workspace ID.
        db: Database session.

    Returns:
        A ``StreamingResponse`` (``text/event-stream``) when ``request.stream``
        is truthy, otherwise a success response carrying the chat result.

    Raises:
        HTTPException: 404 when the given conversation does not exist, 400 for
            any ``ValueError`` (including a malformed conversation ID), 500
            for every other failure.
    """
    try:
        workspace_id = current_user.current_workspace_id

        # Resolve the conversation: reuse the one named in the request or create a draft.
        conversation_service = ConversationService(db)
        if request.conversation_id:
            conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
            if not conversation:
                raise HTTPException(status_code=404, detail="会话不存在")
        else:
            conversation = conversation_service.create_or_get_conversation(
                app_id=app_id,
                workspace_id=workspace_id,
                user_id=request.user_id,
                is_draft=True
            )
        conversation_id = str(conversation.id)

        if request.stream:
            # Streaming mode: hand the service's event stream to the client.
            service = get_handoffs_service_for_app(app_id, db, streaming=True)
            return StreamingResponse(
                service.chat_stream(
                    message=request.message,
                    conversation_id=conversation_id
                ),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no"
                }
            )

        # Non-streaming mode: await the complete result.
        service = get_handoffs_service_for_app(app_id, db, streaming=False)
        result = await service.chat(
            message=request.message,
            conversation_id=conversation_id
        )
        return success(data=result, msg="Handoffs 测试成功")

    except HTTPException:
        # Already an HTTP error (e.g. the 404 above): pass it through untouched.
        raise
    except ValueError as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        api_logger.error(f"Handoffs 测试失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
def get_handoff_agents(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    db: Session = Depends(get_db),
    current_user=Depends(get_current_user)
):
    """Return the Handoff agents configured for an application.

    Args:
        app_id: Application ID taken from the path.
        db: Database session passed to the handoffs service factory.
        current_user: Authenticated user (dependency only; not read here).

    Returns:
        ApiResponse: A success response whose data is ``{"agents": [...]}``.

    Raises:
        HTTPException: 400 for a ``ValueError`` from the service, 500 for any
            other failure. An ``HTTPException`` raised while resolving the
            service is passed through unchanged.
    """
    try:
        service = get_handoffs_service_for_app(app_id, db, streaming=False)
        agents = service.get_agents()
        return success(data={"agents": agents}, msg="获取 Agent 列表成功")
    except HTTPException:
        # Do not rewrap an existing HTTP error as a 500.
        raise
    except ValueError as e:
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        api_logger.error(f"获取 Agent 列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/handoffs/{app_id}/reset")
def reset_handoff_service(
    app_id: uuid.UUID = Path(..., description="应用 ID"),
    current_user=Depends(get_current_user)
):
    """Reset the cached Handoff service of the given application.

    Args:
        app_id: Application ID taken from the path; forwarded to
            ``reset_handoffs_service_cache``.
        current_user: Authenticated user (dependency only; not read here).

    Returns:
        A success response confirming the reset.
    """
    reset_handoffs_service_cache(app_id)
    confirmation = success(msg="Handoff 服务已重置")
    return confirmation
|
||||
|
||||
@@ -215,8 +215,8 @@ async def sync_mcp_tools(
|
||||
"""同步MCP工具列表"""
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -11,14 +11,22 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_node_statistics,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.schemas.user_memory_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
ExplicitMemoryOverviewRequest,
|
||||
ExplicitMemoryDetailsRequest,
|
||||
)
|
||||
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
@@ -41,24 +49,27 @@ router = APIRouter(
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的记忆洞察报告"""
|
||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的记忆洞察报告
|
||||
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||
@@ -66,24 +77,27 @@ async def get_memory_insight_report_api(
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的用户摘要"""
|
||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
) -> dict:
|
||||
"""
|
||||
获取缓存的用户摘要
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
"""
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
return success(data=result, msg="数据尚未生成")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
@@ -97,35 +111,35 @@ async def generate_cache_api(
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
group_id = request.end_user_id
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
@@ -133,7 +147,7 @@ async def generate_cache_api(
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
|
||||
# 收集错误信息
|
||||
if not insight_result["success"]:
|
||||
result["errors"].append({
|
||||
@@ -145,29 +159,29 @@ async def generate_cache_api(
|
||||
"type": "summary",
|
||||
"error": summary_result.get("error")
|
||||
})
|
||||
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
|
||||
else:
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 批量生成完成: "
|
||||
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
|
||||
)
|
||||
|
||||
|
||||
return success(data=result, msg="批量生成完成")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
|
||||
@@ -180,18 +194,18 @@ async def get_node_statistics_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 调用新的记忆类型统计函数
|
||||
result = await analytics_memory_types(db, end_user_id)
|
||||
|
||||
|
||||
# 计算总数用于日志
|
||||
total_count = sum(item["count"] for item in result)
|
||||
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
|
||||
@@ -211,31 +225,31 @@ async def get_graph_data_api(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
# 参数验证
|
||||
if limit > 1000:
|
||||
limit = 1000
|
||||
api_logger.warning("limit 参数超过最大值,已调整为 1000")
|
||||
|
||||
|
||||
if depth > 3:
|
||||
depth = 3
|
||||
api_logger.warning("depth 参数超过最大值,已调整为 3")
|
||||
|
||||
|
||||
# 解析 node_types 参数
|
||||
node_types_list = None
|
||||
if node_types:
|
||||
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
result = await analytics_graph_data(
|
||||
db=db,
|
||||
@@ -245,19 +259,19 @@ async def get_graph_data_api(
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
@@ -270,25 +284,25 @@ async def get_end_user_profile(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -300,10 +314,10 @@ async def get_end_user_profile(
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
@@ -317,56 +331,56 @@ async def update_end_user_profile(
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
|
||||
# 更新字段(只更新提供的字段,排除 end_user_id)
|
||||
# 允许 None 值来重置字段(如 hire_date)
|
||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
|
||||
# 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime
|
||||
if 'hire_date' in update_data:
|
||||
hire_date_timestamp = update_data['hire_date']
|
||||
if hire_date_timestamp is not None:
|
||||
update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp)
|
||||
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
||||
# 如果是 None,保持 None(允许清空)
|
||||
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(end_user, field, value)
|
||||
|
||||
|
||||
# 更新 updated_at 时间戳
|
||||
end_user.updated_at = datetime.datetime.now()
|
||||
|
||||
|
||||
# 更新 updatetime_profile 为当前时间
|
||||
end_user.updatetime_profile = datetime.datetime.now()
|
||||
|
||||
|
||||
# 提交更改
|
||||
db.commit()
|
||||
db.refresh(end_user)
|
||||
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -378,11 +392,243 @@ async def update_end_user_profile(
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(
    id: str,
    label: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return the shared-memory timeline for one memory entity.

    Args:
        id: Entity identifier, forwarded unchanged to ``MemoryEntityService``.
        label: Entity label, forwarded unchanged to ``MemoryEntityService``.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Injected database session; not used by this handler itself.

    Returns:
        A ``success`` envelope carrying the timeline, or a ``fail`` envelope
        with ``BizCode.INTERNAL_ERROR`` when the lookup raises.
    """
    api_logger.info(f"共同记忆时间线查询请求: id={id}, label={label}, user={current_user.username}")

    # Mirror the error handling of the sibling endpoints so a service failure
    # is logged and reported through the standard response envelope.
    try:
        entity_service = MemoryEntityService(id, label)
        timeline_memories_result = await entity_service.get_timeline_memories_server()
    except Exception as e:
        api_logger.error(f"共同记忆时间线查询失败: id={id}, label={label}, error={str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "共同记忆时间线查询失败", str(e))

    return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(
    id: str,
    label: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Return emotion and interaction-frequency data for one memory entity.

    Args:
        id: Entity identifier, forwarded to ``MemoryEmotion`` and
            ``MemoryInteraction``.
        label: Entity label (logged as ``table``), forwarded to the same helpers.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Injected database session; not used by this handler itself.

    Returns:
        A ``success`` envelope with ``{"emotion": ..., "interaction": ...}``,
        or a ``fail`` envelope with ``BizCode.INTERNAL_ERROR`` on any error.
    """
    # Tracked outside the try block so the finally clause can release whatever
    # was actually opened, even when a constructor or query raises.
    emotion = None
    interaction = None
    try:
        api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")

        # Emotion data
        emotion = MemoryEmotion(id, label)
        emotion_result = await emotion.get_emotion()

        # Interaction data
        interaction = MemoryInteraction(id, label)
        interaction_result = await interaction.get_interaction_frequency()

        result = {
            "emotion": emotion_result,
            "interaction": interaction_result
        }

        api_logger.info(f"关系演变查询成功: id={id}, table={label}")
        return success(data=result, msg="关系演变")

    except Exception as e:
        api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)
        return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e))

    finally:
        # Close on every path; previously a failing query skipped both close()
        # calls and leaked the connections.
        for connection in (emotion, interaction):
            if connection is None:
                continue
            try:
                await connection.close()
            except Exception as close_error:
                # A cleanup failure must not mask the response already chosen.
                api_logger.warning(
                    f"关系演变连接关闭失败: id={id}, table={label}, error={str(close_error)}"
                )
|
||||
|
||||
|
||||
@router.post("/classifications/episodic-memory", response_model=ApiResponse)
async def get_episodic_memory_overview_api(
    request: EpisodicMemoryOverviewRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return the episodic-memory overview for an end user.

    The request may narrow the listing by time range, episodic type and a
    title keyword. The time range and episodic type are checked against fixed
    allow-lists before the service layer is called.

    Returns:
        A ``success`` envelope with the service result, or a ``fail`` envelope
        when no workspace is selected, a filter value is invalid, or the
        service raises.
    """
    workspace_id = current_user.current_workspace_id

    # A workspace must be selected before any query is attempted.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    # Allowed filter values.
    allowed_time_ranges = ["all", "today", "this_week", "this_month"]
    allowed_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]

    if request.time_range not in allowed_time_ranges:
        time_range_options = ', '.join(allowed_time_ranges)
        return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{time_range_options}")

    if request.episodic_type not in allowed_episodic_types:
        episodic_type_options = ', '.join(allowed_episodic_types)
        return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{episodic_type_options}")

    # Normalise the keyword: trim surrounding whitespace, treat empty as absent.
    raw_keyword = request.title_keyword
    title_keyword = raw_keyword.strip() if raw_keyword else None

    api_logger.info(
        f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
        f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, "
        f"title_keyword={title_keyword}"
    )

    try:
        # Delegate to the service layer.
        overview = await user_memory_service.get_episodic_memory_overview(
            db,
            request.end_user_id,
            request.time_range,
            request.episodic_type,
            title_keyword,
        )

        # Kept inside the try block: a result without 'total' is reported as a
        # failure, exactly like a service error.
        api_logger.info(
            f"成功获取情景记忆总览: end_user_id={request.end_user_id}, "
            f"total={overview['total']}"
        )
        return success(data=overview, msg="查询成功")

    except Exception as e:
        error_text = str(e)
        api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={error_text}")
        return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", error_text)
|
||||
|
||||
|
||||
@router.post("/classifications/episodic-memory-details", response_model=ApiResponse)
async def get_episodic_memory_details_api(
    request: EpisodicMemoryDetailsRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return the details of a single episodic memory.

    The memory is looked up through the service layer by ``end_user_id`` and
    ``summary_id`` taken from the request.

    Returns:
        A ``success`` envelope with the service result. A ``ValueError`` from
        the service is reported as ``INVALID_PARAMETER`` (memory not found);
        any other exception is reported as ``INTERNAL_ERROR``.
    """
    workspace_id = current_user.current_workspace_id

    # A workspace must be selected before any query is attempted.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    end_user_id = request.end_user_id
    summary_id = request.summary_id

    api_logger.info(
        f"情景记忆详情查询请求: end_user_id={end_user_id}, summary_id={summary_id}, "
        f"user={current_user.username}, workspace={workspace_id}"
    )

    try:
        # Delegate to the service layer.
        details = await user_memory_service.get_episodic_memory_details(
            db=db,
            end_user_id=end_user_id,
            summary_id=summary_id
        )

        api_logger.info(
            f"成功获取情景记忆详情: end_user_id={end_user_id}, summary_id={summary_id}"
        )
        return success(data=details, msg="查询成功")

    except ValueError as e:
        # The service signals a missing episodic memory with ValueError.
        reason = str(e)
        api_logger.warning(f"情景记忆不存在: end_user_id={end_user_id}, summary_id={summary_id}, error={reason}")
        return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", reason)
    except Exception as e:
        reason = str(e)
        api_logger.error(f"情景记忆详情查询失败: end_user_id={end_user_id}, summary_id={summary_id}, error={reason}")
        return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", reason)
|
||||
|
||||
|
||||
@router.post("/classifications/explicit-memory", response_model=ApiResponse)
async def get_explicit_memory_overview_api(
    request: ExplicitMemoryOverviewRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return the explicit-memory overview for an end user.

    Returns:
        A ``success`` envelope with the service result, or a ``fail`` envelope
        when no workspace is selected or the service raises.
    """
    workspace_id = current_user.current_workspace_id

    # A workspace must be selected before any query is attempted.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    end_user_id = request.end_user_id

    api_logger.info(
        f"显性记忆总览查询请求: end_user_id={end_user_id}, user={current_user.username}, "
        f"workspace={workspace_id}"
    )

    try:
        # Delegate to the service layer.
        overview = await user_memory_service.get_explicit_memory_overview(db, end_user_id)

        # Kept inside the try block: a result without 'total' is reported as a
        # failure, exactly like a service error.
        api_logger.info(
            f"成功获取显性记忆总览: end_user_id={end_user_id}, "
            f"total={overview['total']}"
        )
        return success(data=overview, msg="查询成功")

    except Exception as e:
        reason = str(e)
        api_logger.error(f"显性记忆总览查询失败: end_user_id={end_user_id}, error={reason}")
        return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", reason)
|
||||
|
||||
|
||||
@router.post("/classifications/explicit-memory-details", response_model=ApiResponse)
async def get_explicit_memory_details_api(
    request: ExplicitMemoryDetailsRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return the details of a single explicit memory.

    The memory is looked up through the service layer by ``end_user_id`` and
    ``memory_id``; the result's ``memory_type`` entry is included in the
    success log line.

    Returns:
        A ``success`` envelope with the service result. A ``ValueError`` from
        the service is reported as ``INVALID_PARAMETER`` (memory not found);
        any other exception is reported as ``INTERNAL_ERROR``.
    """
    workspace_id = current_user.current_workspace_id

    # A workspace must be selected before any query is attempted.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    end_user_id = request.end_user_id
    memory_id = request.memory_id

    api_logger.info(
        f"显性记忆详情查询请求: end_user_id={end_user_id}, memory_id={memory_id}, "
        f"user={current_user.username}, workspace={workspace_id}"
    )

    try:
        # Delegate to the service layer.
        details = await user_memory_service.get_explicit_memory_details(
            db=db,
            end_user_id=end_user_id,
            memory_id=memory_id
        )

        api_logger.info(
            f"成功获取显性记忆详情: end_user_id={end_user_id}, memory_id={memory_id}, "
            f"memory_type={details.get('memory_type')}"
        )
        return success(data=details, msg="查询成功")

    except ValueError as e:
        # The service signals a missing memory with ValueError.
        reason = str(e)
        api_logger.warning(f"显性记忆不存在: end_user_id={end_user_id}, memory_id={memory_id}, error={reason}")
        return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", reason)
    except Exception as e:
        reason = str(e)
        api_logger.error(f"显性记忆详情查询失败: end_user_id={end_user_id}, memory_id={memory_id}, error={reason}")
        return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", reason)
|
||||
|
||||
|
||||
|
||||
@@ -11,10 +11,16 @@ import os
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
@@ -159,11 +165,13 @@ class LangChainAgent:
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
retrieved_content=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
return messagss_list
|
||||
retrieved_content.append({query: aimessages})
|
||||
return messagss_list,retrieved_content
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
@@ -203,7 +211,6 @@ class LangChainAgent:
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
@@ -221,11 +228,26 @@ class LangChainAgent:
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
|
||||
history_term_memory=await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
db_for_memory = next(get_db())
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory=';'.join(history_term_memory)
|
||||
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
print(retrieved_content)
|
||||
# 为长期记忆操作获取新的数据库连接
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
raise
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
try:
|
||||
@@ -314,10 +336,6 @@ class LangChainAgent:
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.db import get_db
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
@@ -329,14 +347,24 @@ class LangChainAgent:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
history_term_memory = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
db_for_memory = next(get_db())
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to long term memory: {e}")
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
try:
|
||||
|
||||
@@ -164,6 +164,10 @@ class Settings:
|
||||
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
|
||||
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
|
||||
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
|
||||
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
|
||||
SYSTEM_INTRODUCTION: str = os.getenv("SYSTEM_INTRODUCTION", "")
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
|
||||
@@ -82,6 +82,13 @@ class BizCode(IntEnum):
|
||||
MEMORY_WRITE_FAILED = 9501
|
||||
MEMORY_READ_FAILED = 9502
|
||||
MEMORY_CONFIG_NOT_FOUND = 9503
|
||||
|
||||
# Implicit Memory API(96xx)
|
||||
INVALID_USER_ID = 9601
|
||||
INSUFFICIENT_DATA = 9602
|
||||
INVALID_FILTER_PARAMS = 9603
|
||||
ANALYSIS_FAILED = 9604
|
||||
PROFILE_STORAGE_ERROR = 9605
|
||||
|
||||
# 系统(100xx)
|
||||
INTERNAL_ERROR = 10001
|
||||
@@ -159,6 +166,13 @@ HTTP_MAPPING = {
|
||||
BizCode.MEMORY_READ_FAILED: 500,
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
|
||||
|
||||
# Implicit Memory API 错误码映射
|
||||
BizCode.INVALID_USER_ID: 400,
|
||||
BizCode.INSUFFICIENT_DATA: 400,
|
||||
BizCode.INVALID_FILTER_PARAMS: 400,
|
||||
BizCode.ANALYSIS_FAILED: 500,
|
||||
BizCode.PROFILE_STORAGE_ERROR: 500,
|
||||
|
||||
BizCode.INTERNAL_ERROR: 500,
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
|
||||
@@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system.
|
||||
|
||||
Available functions:
|
||||
- get_hot_memory_tags: Get hot memory tags by frequency
|
||||
- MemoryInsight: Generate memory insight reports
|
||||
- get_recent_activity_stats: Get recent activity statistics
|
||||
- generate_user_summary: Generate user summary
|
||||
|
||||
Note: MemoryInsight and generate_user_summary have been moved to
|
||||
app.services.user_memory_service for better architecture.
|
||||
"""
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
|
||||
__all__ = [
|
||||
"get_hot_memory_tags",
|
||||
"MemoryInsight",
|
||||
"get_recent_activity_stats",
|
||||
"generate_user_summary",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Implicit Memory Module
|
||||
|
||||
This module provides behavior analysis capabilities that build comprehensive user profiles
|
||||
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
|
||||
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
|
||||
"""
|
||||
@@ -0,0 +1 @@
|
||||
"""Analyzers package for implicit memory analysis components."""
|
||||
@@ -0,0 +1,271 @@
|
||||
"""Dimension Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based personality dimension analysis from user memory summaries.
|
||||
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
|
||||
providing percentage scores with evidence and reasoning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
DimensionPortrait,
|
||||
DimensionScore,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DimensionData(BaseModel):
    """Analysis payload for a single personality dimension."""

    # Score for the dimension; validation restricts it to 0.0-100.0.
    percentage: float = Field(ge=0.0, le=100.0)
    # Supporting evidence strings; a fresh empty list per instance by default.
    evidence: List[str] = Field(default_factory=list)
    # Free-text justification for the score.
    reasoning: str = Field(default="")
    # Numeric confidence; 50 stands for medium confidence.
    confidence_level: int = Field(default=50)
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
    """Structured response of a dimension analysis run."""

    # Maps a dimension name to its analysis data; empty by default.
    dimensions: Dict[str, DimensionData] = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
class DimensionAnalyzer:
|
||||
"""Analyzes user memory summaries to extract personality dimensions."""
|
||||
|
||||
# Define the four dimensions we analyze
|
||||
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the dimension analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_dimensions(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_portrait: Optional[DimensionPortrait] = None
|
||||
) -> DimensionPortrait:
|
||||
"""Analyze user summaries to extract personality dimensions.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_portrait: Optional existing portrait for incremental updates
|
||||
|
||||
Returns:
|
||||
Dimension portrait with four personality dimensions
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return self._create_empty_portrait(user_id)
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_dimensions(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Create dimension scores
|
||||
dimension_scores = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for dimension_name in self.DIMENSIONS:
|
||||
# Handle response as dictionary
|
||||
dimensions_data = response.get("dimensions", {})
|
||||
dimension_data = dimensions_data.get(dimension_name)
|
||||
|
||||
if dimension_data:
|
||||
# Validate and create dimension score
|
||||
score = self._create_dimension_score(
|
||||
dimension_name=dimension_name,
|
||||
dimension_data=dimension_data
|
||||
)
|
||||
dimension_scores[dimension_name] = score
|
||||
else:
|
||||
# Create default score if missing
|
||||
logger.warning(f"Missing dimension data for {dimension_name}, using default")
|
||||
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
literature=dimension_scores["literature"],
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=len(user_summaries),
|
||||
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
|
||||
)
|
||||
|
||||
logger.info(f"Created dimension portrait for user {user_id}")
|
||||
return portrait
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Dimension analysis failed: {e}") from e
|
||||
|
||||
def _create_dimension_score(
|
||||
self,
|
||||
dimension_name: str,
|
||||
dimension_data: dict
|
||||
) -> DimensionScore:
|
||||
"""Create a dimension score from analysis data.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
dimension_data: Analysis data dictionary for the dimension
|
||||
|
||||
Returns:
|
||||
DimensionScore object
|
||||
"""
|
||||
# Validate percentage - handle dict access
|
||||
percentage = dimension_data.get("percentage", 0.0)
|
||||
percentage = max(0.0, min(100.0, float(percentage)))
|
||||
|
||||
# Validate confidence level
|
||||
confidence_level = self._validate_confidence_level(dimension_data.get("confidence_level", 50))
|
||||
|
||||
# Ensure evidence is not empty
|
||||
evidence = dimension_data.get("evidence", [])
|
||||
if not evidence:
|
||||
evidence = ["No specific evidence found"]
|
||||
|
||||
# Ensure reasoning is not empty
|
||||
reasoning = dimension_data.get("reasoning", "")
|
||||
if not reasoning:
|
||||
reasoning = f"Analysis for {dimension_name} dimension"
|
||||
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=percentage,
|
||||
evidence=evidence,
|
||||
reasoning=reasoning,
|
||||
confidence_level=confidence_level
|
||||
)
|
||||
|
||||
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
|
||||
"""Create a default dimension score when analysis fails.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
|
||||
Returns:
|
||||
Default DimensionScore object
|
||||
"""
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=0.0,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
reasoning=f"No clear evidence found for {dimension_name} dimension",
|
||||
confidence_level=20 # Low confidence as numerical value
|
||||
)
|
||||
|
||||
def _validate_confidence_level(self, confidence_level) -> int:
|
||||
"""Return confidence level as integer, handling both string and numeric inputs.
|
||||
|
||||
Args:
|
||||
confidence_level: Confidence level (string or numeric)
|
||||
|
||||
Returns:
|
||||
Confidence level as integer (0-100)
|
||||
"""
|
||||
# If it's already a number, return it as int
|
||||
if isinstance(confidence_level, (int, float)):
|
||||
return int(confidence_level)
|
||||
|
||||
# If it's a string, convert common values to numbers
|
||||
if isinstance(confidence_level, str):
|
||||
confidence_str = confidence_level.lower().strip()
|
||||
if confidence_str in ["high", "높음"]:
|
||||
return 85
|
||||
elif confidence_str in ["medium", "중간"]:
|
||||
return 50
|
||||
elif confidence_str in ["low", "낮음"]:
|
||||
return 20
|
||||
else:
|
||||
# Try to parse as number
|
||||
try:
|
||||
return int(float(confidence_str))
|
||||
except ValueError:
|
||||
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
|
||||
return 50
|
||||
|
||||
# Default fallback
|
||||
return 50
|
||||
|
||||
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
    """Build the portrait returned when there is nothing to analyse.

    Args:
        user_id: Target user ID

    Returns:
        DimensionPortrait whose four dimensions all carry the default score,
        with zero analysed summaries and no historical trends
    """
    generated_at = datetime.now()

    # One default score per dimension, keyed by the portrait field name.
    placeholder_scores = {
        dimension: self._create_default_dimension_score(dimension)
        for dimension in ("creativity", "aesthetic", "technology", "literature")
    }

    return DimensionPortrait(
        user_id=user_id,
        analysis_timestamp=generated_at,
        total_summaries_analyzed=0,
        historical_trends=None,
        **placeholder_scores,
    )
|
||||
|
||||
def _calculate_historical_trends(
|
||||
self,
|
||||
existing_portrait: DimensionPortrait
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Calculate historical trends from existing portrait.
|
||||
|
||||
Args:
|
||||
existing_portrait: Previous dimension portrait
|
||||
|
||||
Returns:
|
||||
List of historical trend data
|
||||
"""
|
||||
if not existing_portrait:
|
||||
return []
|
||||
|
||||
# Create trend entry from existing portrait
|
||||
trend_entry = {
|
||||
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
|
||||
"creativity": existing_portrait.creativity.percentage,
|
||||
"aesthetic": existing_portrait.aesthetic.percentage,
|
||||
"technology": existing_portrait.technology.percentage,
|
||||
"literature": existing_portrait.literature.percentage,
|
||||
"total_summaries": existing_portrait.total_summaries_analyzed
|
||||
}
|
||||
|
||||
# Combine with existing trends
|
||||
existing_trends = existing_portrait.historical_trends or []
|
||||
|
||||
# Keep only recent trends (last 10 entries)
|
||||
all_trends = existing_trends + [trend_entry]
|
||||
return all_trends[-10:]
|
||||
@@ -0,0 +1,459 @@
|
||||
"""Habit Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based behavioral habit analysis from user memory summaries.
|
||||
It identifies recurring behavioral patterns, temporal patterns, and consolidates
|
||||
similar habits with confidence scoring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitData(BaseModel):
    """A single habit entry as produced by habit analysis."""

    habit_description: str
    frequency_pattern: str
    time_context: str
    # 50 stands for "medium" confidence
    confidence_level: int = Field(default=50)
    supporting_summaries: List[str] = Field(default_factory=list)
    specific_examples: List[str] = Field(default_factory=list)
    is_current: bool = Field(default=True)
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
    """Container for the list of habits returned by habit analysis."""

    # Defaults to an empty list when no habits are reported
    habits: List[HabitData] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HabitAnalyzer:
    """Analyzes user memory summaries to extract behavioral habits.

    The analyzer asks the LLM client for candidate habits, converts each
    candidate into a ``BehaviorHabit``, drops invalid ones, optionally merges
    the result with previously known habits and finally orders the habits by
    priority.

    Timestamps are handled as naive UTC datetimes: timezone-aware values are
    converted to UTC before their offset is dropped, and "now" is taken in
    UTC as well.
    """

    # Value used when a confidence level cannot be interpreted ("medium").
    _DEFAULT_CONFIDENCE = 50
    # Textual confidence labels mapped to a 0-100 score.
    _CONFIDENCE_LABELS = {
        "high": 85, "높음": 85,
        "medium": 50, "중간": 50,
        "low": 20, "낮음": 20,
    }
    # A merged habit is "current" when last observed within this many days.
    _CURRENT_WINDOW_DAYS = 30

    def __init__(self, db: Session, llm_model_id: Optional[str] = None):
        """Initialize the habit analyzer.

        Args:
            db: Database session
            llm_model_id: Optional LLM model ID to use for analysis
        """
        self.db = db
        self.llm_model_id = llm_model_id
        self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)

    async def analyze_habits(
        self,
        user_id: str,
        user_summaries: List[UserMemorySummary],
        existing_habits: Optional[List[BehaviorHabit]] = None
    ) -> List[BehaviorHabit]:
        """Analyze user summaries to extract behavioral habits.

        Args:
            user_id: Target user ID
            user_summaries: List of user-specific memory summaries
            existing_habits: Optional existing habits for consolidation

        Returns:
            List of extracted behavioral habits, consolidated with
            ``existing_habits`` when given and sorted by priority. When no
            summaries are supplied the existing habits (or an empty list) are
            returned unchanged.

        Raises:
            LLMClientException: If LLM analysis fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for user {user_id}")
            return existing_habits or []

        try:
            logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")

            # Use the LLM client wrapper for analysis
            response = await self._llm_client.analyze_habits(
                user_summaries=user_summaries,
                user_id=user_id,
                model_id=self.llm_model_id
            )

            behavior_habits = []

            # "habits" may be present but null, hence `or []`.
            for habit_data in response.get("habits") or []:
                # Best effort: one malformed entry must not abort the others.
                try:
                    supporting_summaries = habit_data.get("supporting_summaries") or []
                    specific_examples = habit_data.get("specific_examples") or []

                    first_observed, last_observed = self._determine_observation_dates(
                        user_summaries, supporting_summaries
                    )

                    behavior_habit = BehaviorHabit(
                        habit_description=habit_data.get("habit_description", ""),
                        frequency_pattern=self._validate_frequency_pattern(
                            habit_data.get("frequency_pattern", "occasional")
                        ),
                        time_context=habit_data.get("time_context", ""),
                        confidence_level=self._validate_confidence_level(
                            habit_data.get("confidence_level", self._DEFAULT_CONFIDENCE)
                        ),
                        specific_examples=specific_examples,
                        first_observed=first_observed,
                        last_observed=last_observed,
                        is_current=habit_data.get("is_current", True)
                    )

                    if self._is_valid_habit(behavior_habit):
                        behavior_habits.append(behavior_habit)
                    else:
                        logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")

                except Exception as e:
                    logger.error(f"Error creating behavior habit: {e}")
                    continue

            # Consolidate with existing habits if provided
            if existing_habits:
                behavior_habits = self._consolidate_habits(
                    new_habits=behavior_habits,
                    existing_habits=existing_habits
                )

            # Sort habits by confidence and recency
            behavior_habits = self._sort_habits_by_priority(behavior_habits)

            logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
            return behavior_habits

        except LLMClientException:
            raise
        except Exception as e:
            logger.error(f"Habit analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Habit analysis failed: {e}") from e

    @staticmethod
    def _utc_now_naive() -> datetime:
        """Return the current UTC time as a naive datetime."""
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _to_naive_utc(ts: datetime) -> datetime:
        """Return ``ts`` as a naive datetime, converting aware values to UTC first."""
        from datetime import timezone

        if ts.tzinfo is not None:
            return ts.astimezone(timezone.utc).replace(tzinfo=None)
        return ts

    def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
        """Validate and convert frequency pattern string.

        Args:
            frequency_str: Frequency pattern as string. Case, surrounding
                whitespace and the choice of space / hyphen / underscore as
                word separator are ignored.

        Returns:
            FrequencyPattern enum value; ``FrequencyPattern.OCCASIONAL`` for
            unknown, null or non-string input
        """
        if not isinstance(frequency_str, str):
            return FrequencyPattern.OCCASIONAL

        # "event-triggered", "event triggered" and "event_triggered" all
        # resolve to the same key.
        key = frequency_str.lower().strip().replace("-", "_").replace(" ", "_")

        frequency_mapping = {
            "daily": FrequencyPattern.DAILY,
            "weekly": FrequencyPattern.WEEKLY,
            "monthly": FrequencyPattern.MONTHLY,
            "seasonal": FrequencyPattern.SEASONAL,
            "occasional": FrequencyPattern.OCCASIONAL,
            "event_triggered": FrequencyPattern.EVENT_TRIGGERED,
        }

        return frequency_mapping.get(key, FrequencyPattern.OCCASIONAL)

    def _validate_confidence_level(self, confidence_level) -> int:
        """Coerce a confidence value to an integer in the 0-100 range.

        Args:
            confidence_level: Confidence level as a number, a numeric string
                or one of the labels in ``_CONFIDENCE_LABELS``

        Returns:
            Confidence level as integer, clamped to 0-100. Strings that are
            neither a known label nor a finite number, and non-finite floats,
            fall back to 50 (medium) with a warning. Any other input type
            falls back to 50 silently.
        """
        import math  # local import: only this helper needs it

        if isinstance(confidence_level, int):
            # Clamp directly; going through float() could overflow on huge ints.
            return int(max(0, min(100, confidence_level)))

        if isinstance(confidence_level, float):
            numeric = confidence_level
        elif isinstance(confidence_level, str):
            confidence_str = confidence_level.lower().strip()
            if confidence_str in self._CONFIDENCE_LABELS:
                return self._CONFIDENCE_LABELS[confidence_str]
            try:
                numeric = float(confidence_str)
            except ValueError:
                numeric = None
        else:
            # Default fallback for None and other unsupported types
            return self._DEFAULT_CONFIDENCE

        # NaN / infinity cannot be converted with int(); treat them as unknown.
        if numeric is None or not math.isfinite(numeric):
            logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
            return self._DEFAULT_CONFIDENCE

        return max(0, min(100, int(numeric)))

    def _determine_observation_dates(
        self,
        user_summaries: List[UserMemorySummary],
        supporting_summary_ids: List[str]
    ) -> tuple[datetime, datetime]:
        """Determine first and last observation dates for a habit.

        Args:
            user_summaries: List of user summaries
            supporting_summary_ids: IDs of summaries supporting the habit
                (``None`` is treated as an empty list)

        Returns:
            Tuple of (first_observed, last_observed) as naive UTC datetimes.
            When none of the IDs match, all summaries are used; when there
            are no summaries at all, both values are the current UTC time.
        """
        wanted_ids = supporting_summary_ids or []

        # Find summaries that support this habit
        supporting_summaries = [
            summary for summary in user_summaries
            if summary.summary_id in wanted_ids
        ]

        if not supporting_summaries:
            # Use all summaries if no specific supporting summaries found
            supporting_summaries = user_summaries

        if not supporting_summaries:
            current_time = self._utc_now_naive()
            return current_time, current_time

        # Aware timestamps are converted to UTC before the offset is dropped,
        # so values from different zones compare correctly.
        timestamps = [
            self._to_naive_utc(summary.timestamp)
            for summary in supporting_summaries
        ]

        return min(timestamps), max(timestamps)

    def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
        """Validate a behavioral habit.

        A habit is valid when it has a non-blank description, a non-blank
        time context, at least one specific example and a first observation
        that is not later than the last one.

        Args:
            habit: Behavioral habit to validate

        Returns:
            True if valid, False otherwise (including when validation itself
            raises)
        """
        try:
            # Check required fields
            if not habit.habit_description or not habit.habit_description.strip():
                return False

            # Check time context
            if not habit.time_context or not habit.time_context.strip():
                return False

            # Check specific examples
            if not habit.specific_examples:
                return False

            # Check observation dates
            if habit.first_observed > habit.last_observed:
                return False

            return True

        except Exception as e:
            logger.error(f"Error validating habit: {e}")
            return False

    def _consolidate_habits(
        self,
        new_habits: List[BehaviorHabit],
        existing_habits: List[BehaviorHabit],
        similarity_threshold: float = 0.7
    ) -> List[BehaviorHabit]:
        """Consolidate new habits with existing ones.

        Args:
            new_habits: Newly extracted habits
            existing_habits: Existing habits
            similarity_threshold: Threshold for considering habits similar

        Returns:
            Consolidated list of habits: the existing habits (each replaced
            by its merged version when a new habit matched it) followed by
            the new habits that matched nothing
        """
        consolidated = existing_habits.copy()
        existing_count = len(consolidated)
        current_time = self._utc_now_naive()

        for new_habit in new_habits:
            # Match against the *current* versions of the existing habits so
            # that several new habits hitting the same entry accumulate
            # instead of overwriting each other's merge.
            candidates = consolidated[:existing_count]
            similar_habit = self._find_similar_habit(
                new_habit, candidates, similarity_threshold
            )

            if similar_habit is None:
                # Add as new habit
                consolidated.append(new_habit)
                continue

            updated_habit = self._merge_habits(similar_habit, new_habit, current_time)

            # Replace by identity: descriptions are not guaranteed unique.
            for i, habit in enumerate(candidates):
                if habit is similar_habit:
                    consolidated[i] = updated_habit
                    break

        return consolidated

    def _find_similar_habit(
        self,
        target_habit: BehaviorHabit,
        existing_habits: List[BehaviorHabit],
        threshold: float
    ) -> Optional[BehaviorHabit]:
        """Find similar habit in existing list.

        The score is 0.6 * description similarity + 0.4 * time-context
        similarity, plus 0.1 when the frequency patterns are equal.

        Args:
            target_habit: Habit to find similarity for
            existing_habits: List of existing habits
            threshold: Similarity threshold

        Returns:
            First habit whose score reaches the threshold, None otherwise
        """
        target_desc = target_habit.habit_description.lower().strip()
        target_time = target_habit.time_context.lower()

        for existing_habit in existing_habits:
            existing_desc = existing_habit.habit_description.lower().strip()

            # Check description similarity
            desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)

            # Check time context similarity
            time_similarity = self._calculate_text_similarity(
                target_time,
                existing_habit.time_context.lower()
            )

            # Combined similarity score
            combined_similarity = desc_similarity * 0.6 + time_similarity * 0.4
            if target_habit.frequency_pattern == existing_habit.frequency_pattern:
                combined_similarity += 0.1  # Bonus for frequency match

            if combined_similarity >= threshold:
                return existing_habit

        return None

    def _calculate_text_similarity(self, text1: str, text2: str) -> float:
        """Calculate simple text similarity based on common words.

        The score is the size of the intersection of the two lower-cased,
        whitespace-split word sets divided by the size of their union.

        Args:
            text1: First text
            text2: Second text

        Returns:
            Similarity score between 0.0 and 1.0 (0.0 when either text is
            empty or contains no words)
        """
        if not text1 or not text2:
            return 0.0

        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        return len(words1 & words2) / len(words1 | words2)

    def _merge_habits(
        self,
        existing_habit: BehaviorHabit,
        new_habit: BehaviorHabit,
        current_time: datetime
    ) -> BehaviorHabit:
        """Merge two similar habits.

        Args:
            existing_habit: Existing habit
            new_habit: New habit to merge
            current_time: Current timestamp

        Returns:
            Merged behavioral habit. It keeps the existing description and
            frequency, takes the higher confidence, the union of the examples
            and the widest observation range, and is current when last
            observed within ``_CURRENT_WINDOW_DAYS`` days of ``current_time``.
        """
        # Order-preserving de-duplication keeps the result deterministic.
        combined_examples = list(dict.fromkeys(
            existing_habit.specific_examples + new_habit.specific_examples
        ))

        # Update confidence level (take higher confidence)
        new_confidence = max(existing_habit.confidence_level, new_habit.confidence_level)

        # Normalise before comparing: mixing aware and naive datetimes raises.
        first_observed = min(
            self._to_naive_utc(existing_habit.first_observed),
            self._to_naive_utc(new_habit.first_observed),
        )
        last_observed = max(
            self._to_naive_utc(existing_habit.last_observed),
            self._to_naive_utc(new_habit.last_observed),
        )

        age = self._to_naive_utc(current_time) - last_observed
        is_current = age.days <= self._CURRENT_WINDOW_DAYS

        # Combine time context
        combined_time_context = existing_habit.time_context
        if new_habit.time_context and new_habit.time_context not in combined_time_context:
            combined_time_context += f"; {new_habit.time_context}"

        return BehaviorHabit(
            habit_description=existing_habit.habit_description,  # Keep original description
            frequency_pattern=existing_habit.frequency_pattern,  # Keep original frequency
            time_context=combined_time_context,
            confidence_level=new_confidence,
            specific_examples=combined_examples,
            first_observed=first_observed,
            last_observed=last_observed,
            is_current=is_current
        )

    def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
        """Sort habits by confidence level and recency.

        Args:
            habits: List of habits to sort

        Returns:
            New list sorted in descending order of (confidence level,
            recency score plus a bonus of 100 for current habits, last
            observation time)
        """
        # Taken once so every habit is scored against the same instant.
        now = self._utc_now_naive()

        def priority_score(habit: BehaviorHabit) -> tuple:
            last_seen = self._to_naive_utc(habit.last_observed)

            # Recency score (more recent = higher score), floored at 0
            days_since_last = (now - last_seen).days
            recency_score = max(0, 365 - days_since_last)

            # Current habit bonus
            current_bonus = 100 if habit.is_current else 0

            return (habit.confidence_level, recency_score + current_bonus, last_seen)

        return sorted(habits, key=priority_score, reverse=True)
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Interest Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based interest area analysis from user memory summaries.
|
||||
It categorizes user interests into four areas: tech, lifestyle, music, and art,
|
||||
providing percentage distribution that totals 100%.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
InterestAreaDistribution,
|
||||
InterestCategory,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterestData(BaseModel):
    """Share, evidence and trend reported for one interest category."""

    # Constrained to the closed range 0-100
    percentage: float = Field(ge=0.0, le=100.0)
    evidence: List[str] = Field(default_factory=list)
    trending_direction: Optional[str] = Field(default=None)
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
    """Container for the per-category data returned by interest analysis."""

    # Maps a category name to its data; empty when nothing is reported
    interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalyzer:
    """Analyzes user memory summaries to extract interest area distribution.

    The analyzer asks the LLM client for a per-category breakdown, sanitises
    the reported values, rescales the percentages so they total 100 and,
    when a previous distribution is supplied, derives a trend per category.
    """

    # Define the four interest categories we analyze
    INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]

    def __init__(self, db: Session, llm_model_id: Optional[str] = None):
        """Initialize the interest analyzer.

        Args:
            db: Database session
            llm_model_id: Optional LLM model ID to use for analysis
        """
        self.db = db
        self.llm_model_id = llm_model_id
        self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)

    async def analyze_interests(
        self,
        user_id: str,
        user_summaries: List[UserMemorySummary],
        existing_distribution: Optional[InterestAreaDistribution] = None
    ) -> InterestAreaDistribution:
        """Analyze user summaries to extract interest area distribution.

        Args:
            user_id: Target user ID
            user_summaries: List of user-specific memory summaries
            existing_distribution: Optional existing distribution for trend tracking

        Returns:
            Interest area distribution across four categories. When no
            summaries are supplied an equal-share placeholder is returned.

        Raises:
            LLMClientException: If LLM analysis fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for user {user_id}")
            return self._create_empty_distribution(user_id)

        try:
            logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")

            # Use the LLM client wrapper for analysis
            response = await self._llm_client.analyze_interests(
                user_summaries=user_summaries,
                user_id=user_id,
                model_id=self.llm_model_id
            )

            current_time = datetime.now()

            # The key may be present but null, hence `or {}`.
            interest_distribution = response.get("interest_distribution") or {}

            # Extract and sanitise interest data
            raw_interests = {}
            for category_name in self.INTEREST_CATEGORIES:
                interest_data_dict = interest_distribution.get(category_name)
                if interest_data_dict:
                    evidence = interest_data_dict.get("evidence") or []
                    if isinstance(evidence, str):
                        evidence = [evidence]
                    raw_interests[category_name] = InterestData(
                        # Coerced first: InterestData rejects values outside
                        # 0-100, which would abort the analysis before the
                        # percentages get normalised.
                        percentage=self._coerce_percentage(interest_data_dict.get("percentage")),
                        evidence=evidence,
                        trending_direction=interest_data_dict.get("trending_direction")
                    )
                else:
                    # Create default if missing
                    logger.warning(f"Missing interest data for {category_name}, using default")
                    raw_interests[category_name] = InterestData(
                        percentage=0.0,
                        evidence=["No specific evidence found"],
                        trending_direction=None
                    )

            # Normalize percentages to ensure they sum to 100%
            normalized_interests = self._normalize_percentages(raw_interests)

            # Create interest category objects
            interest_categories = {}
            for category_name in self.INTEREST_CATEGORIES:
                interest_data = normalized_interests[category_name]

                # Derive the trend from the previous distribution when there
                # is one; otherwise keep what the LLM reported.
                if existing_distribution:
                    trending_direction = self._calculate_trending_direction(
                        category_name=category_name,
                        current_percentage=interest_data.percentage,
                        existing_distribution=existing_distribution
                    )
                else:
                    trending_direction = interest_data.trending_direction

                interest_categories[category_name] = InterestCategory(
                    category_name=category_name,
                    percentage=interest_data.percentage,
                    evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
                    trending_direction=trending_direction
                )

            # Create interest area distribution
            distribution = InterestAreaDistribution(
                user_id=user_id,
                tech=interest_categories["tech"],
                lifestyle=interest_categories["lifestyle"],
                music=interest_categories["music"],
                art=interest_categories["art"],
                analysis_timestamp=current_time,
                total_summaries_analyzed=len(user_summaries)
            )

            # Validate that percentages sum to 100%
            total_percentage = distribution.total_percentage
            if not (99.9 <= total_percentage <= 100.1):
                logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")

            logger.info(f"Created interest distribution for user {user_id}")
            return distribution

        except LLMClientException:
            raise
        except Exception as e:
            logger.error(f"Interest analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Interest analysis failed: {e}") from e

    @staticmethod
    def _coerce_percentage(value) -> float:
        """Convert a reported percentage to a float within 0.0-100.0.

        Args:
            value: Number or numeric string (a trailing "%" is tolerated)

        Returns:
            The value clamped to 0.0-100.0; 0.0 for None, NaN or anything
            that cannot be parsed as a number
        """
        import math  # local import: only this helper needs it

        if isinstance(value, str):
            value = value.strip().rstrip("%")
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0
        if math.isnan(number):
            return 0.0
        return min(100.0, max(0.0, number))

    def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
        """Normalize percentages to ensure they sum to 100%.

        Args:
            raw_interests: Raw interest data with potentially unnormalized percentages

        Returns:
            Normalized interest data. When every input is 0 the share is
            split equally; otherwise values are rescaled, rounded to one
            decimal and the rounding remainder is added to the largest
            category.
        """
        if not raw_interests:
            return {}

        # Calculate current total
        total = sum(interest.percentage for interest in raw_interests.values())

        if total == 0:
            # If all percentages are 0, distribute equally
            equal_percentage = 100.0 / len(raw_interests)
            return {
                category_name: InterestData(
                    percentage=equal_percentage,
                    evidence=interest_data.evidence,
                    trending_direction=interest_data.trending_direction
                )
                for category_name, interest_data in raw_interests.items()
            }

        # Normalize to sum to 100%
        normalization_factor = 100.0 / total
        normalized = {
            category_name: InterestData(
                percentage=round(interest_data.percentage * normalization_factor, 1),
                evidence=interest_data.evidence,
                trending_direction=interest_data.trending_direction
            )
            for category_name, interest_data in raw_interests.items()
        }

        # Rounding each share separately can leave the total a few tenths
        # off; give the remainder to the largest category.
        current_total = sum(interest.percentage for interest in normalized.values())
        adjustment = round(100.0 - current_total, 1)
        if adjustment:
            largest_category = max(normalized, key=lambda k: normalized[k].percentage)
            largest = normalized[largest_category]
            adjusted_percentage = round(max(0.0, largest.percentage + adjustment), 1)
            normalized[largest_category] = InterestData(
                percentage=min(100.0, adjusted_percentage),
                evidence=largest.evidence,
                trending_direction=largest.trending_direction
            )

        return normalized

    def _calculate_trending_direction(
        self,
        category_name: str,
        current_percentage: float,
        existing_distribution: InterestAreaDistribution,
        threshold: float = 5.0
    ) -> Optional[str]:
        """Calculate trending direction for an interest category.

        Args:
            category_name: Name of the interest category
            current_percentage: Current percentage for the category
            existing_distribution: Previous distribution for comparison
            threshold: Minimum percentage change to consider a trend

        Returns:
            Trending direction: "increasing", "decreasing", "stable", or None
            when the previous distribution has no such category or the
            comparison fails
        """
        try:
            # Get previous percentage
            previous_category = getattr(existing_distribution, category_name, None)
            if not previous_category:
                return None

            change = current_percentage - previous_category.percentage

            if abs(change) < threshold:
                return "stable"
            if change > 0:
                return "increasing"
            return "decreasing"

        except Exception as e:
            logger.error(f"Error calculating trending direction for {category_name}: {e}")
            return None

    def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
        """Create an empty interest distribution when no data is available.

        Args:
            user_id: Target user ID

        Returns:
            Empty InterestAreaDistribution with equal percentages
        """
        current_time = datetime.now()
        # 100% split evenly across the analysed categories (25% each for four)
        equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)

        def default_category(name: str) -> InterestCategory:
            return InterestCategory(
                category_name=name,
                percentage=equal_percentage,
                evidence=["Insufficient data for analysis"],
                trending_direction=None
            )

        return InterestAreaDistribution(
            user_id=user_id,
            tech=default_category("tech"),
            lifestyle=default_category("lifestyle"),
            music=default_category("music"),
            art=default_category("art"),
            analysis_timestamp=current_time,
            total_summaries_analyzed=0
        )
|
||||
@@ -0,0 +1,302 @@
|
||||
"""Preference Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based preference extraction from user memory summaries.
|
||||
It identifies implicit preferences, consolidates similar preferences, and calculates
|
||||
confidence scores based on evidence strength.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
PreferenceTag,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
    """Container for the preference entries returned by preference analysis."""

    # Each entry is a free-form dictionary; empty list when nothing is reported
    preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PreferenceAnalyzer:
|
||||
"""Analyzes user memory summaries to extract implicit preferences."""
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
    """Set up the analyzer and its LLM client.

    Args:
        db: Database session
        llm_model_id: Optional LLM model ID to use for analysis
    """
    self.db, self.llm_model_id = db, llm_model_id
    # The client receives the same session and model ID as the analyzer.
    self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_preferences(
    self,
    user_id: str,
    user_summaries: List[UserMemorySummary],
    existing_preferences: Optional[List[PreferenceTag]] = None
) -> List[PreferenceTag]:
    """Analyze user summaries to extract preferences.

    Args:
        user_id: Target user ID
        user_summaries: List of user-specific memory summaries
        existing_preferences: Optional existing preferences for consolidation

    Returns:
        List of extracted preference tags, consolidated with
        ``existing_preferences`` when given. When no summaries are supplied
        the existing preferences (or an empty list) are returned unchanged.

    Raises:
        LLMClientException: If LLM analysis fails
    """
    if not user_summaries:
        logger.warning(f"No summaries provided for user {user_id}")
        # Nothing new to analyse: keep what is already known, in line with
        # how the habit analyzer treats its existing habits.
        return existing_preferences or []

    try:
        logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")

        # Use the LLM client wrapper for analysis
        response = await self._llm_client.analyze_preferences(
            user_summaries=user_summaries,
            user_id=user_id,
            model_id=self.llm_model_id
        )

        current_time = datetime.now()
        # Every tag references all analysed summaries; collect the IDs once.
        conversation_refs = [s.summary_id for s in user_summaries]

        # Convert to PreferenceTag objects
        preference_tags = []

        # "preferences" may be present but null, hence `or []`.
        for pref_data in response.get("preferences") or []:
            # Best effort: one malformed entry must not abort the others.
            try:
                preference_tag = PreferenceTag(
                    tag_name=pref_data.get("tag_name", ""),
                    confidence_score=float(pref_data.get("confidence_score") or 0.0),
                    supporting_evidence=pref_data.get("supporting_evidence") or [],
                    context_details=pref_data.get("context_details", ""),
                    category=pref_data.get("category"),
                    # Separate copy per tag so tags never share one list.
                    conversation_references=list(conversation_refs),
                    created_at=current_time,
                    updated_at=current_time
                )

                # Validate preference tag
                if self._is_valid_preference(preference_tag):
                    preference_tags.append(preference_tag)
                else:
                    logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")

            except Exception as e:
                logger.error(f"Error creating preference tag: {e}")
                continue

        # Consolidate with existing preferences if provided
        if existing_preferences:
            preference_tags = self._consolidate_preferences(
                new_preferences=preference_tags,
                existing_preferences=existing_preferences
            )

        logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
        return preference_tags

    except LLMClientException:
        raise
    except Exception as e:
        logger.error(f"Preference analysis failed for user {user_id}: {e}")
        raise LLMClientException(f"Preference analysis failed: {e}") from e
|
||||
|
||||
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
|
||||
"""Validate a preference tag.
|
||||
|
||||
Args:
|
||||
preference: Preference tag to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check required fields
|
||||
if not preference.tag_name or not preference.tag_name.strip():
|
||||
return False
|
||||
|
||||
# Check confidence score range
|
||||
if not (0.0 <= preference.confidence_score <= 1.0):
|
||||
return False
|
||||
|
||||
# Check supporting evidence
|
||||
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
|
||||
return False
|
||||
|
||||
# Check context details
|
||||
if not preference.context_details or not preference.context_details.strip():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating preference: {e}")
|
||||
return False
|
||||
|
||||
def _consolidate_preferences(
|
||||
self,
|
||||
new_preferences: List[PreferenceTag],
|
||||
existing_preferences: List[PreferenceTag],
|
||||
similarity_threshold: float = 0.8
|
||||
) -> List[PreferenceTag]:
|
||||
"""Consolidate new preferences with existing ones.
|
||||
|
||||
Args:
|
||||
new_preferences: Newly extracted preferences
|
||||
existing_preferences: Existing preferences
|
||||
similarity_threshold: Threshold for considering preferences similar
|
||||
|
||||
Returns:
|
||||
Consolidated list of preferences
|
||||
"""
|
||||
consolidated = existing_preferences.copy()
|
||||
current_time = datetime.now()
|
||||
|
||||
for new_pref in new_preferences:
|
||||
# Find similar existing preference
|
||||
similar_pref = self._find_similar_preference(
|
||||
new_pref, existing_preferences, similarity_threshold
|
||||
)
|
||||
|
||||
if similar_pref:
|
||||
# Update existing preference
|
||||
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
|
||||
# Replace in consolidated list
|
||||
for i, pref in enumerate(consolidated):
|
||||
if pref.tag_name == similar_pref.tag_name:
|
||||
consolidated[i] = updated_pref
|
||||
break
|
||||
else:
|
||||
# Add as new preference
|
||||
consolidated.append(new_pref)
|
||||
|
||||
return consolidated
|
||||
|
||||
def _find_similar_preference(
|
||||
self,
|
||||
target_preference: PreferenceTag,
|
||||
existing_preferences: List[PreferenceTag],
|
||||
threshold: float
|
||||
) -> Optional[PreferenceTag]:
|
||||
"""Find similar preference in existing list.
|
||||
|
||||
Args:
|
||||
target_preference: Preference to find similarity for
|
||||
existing_preferences: List of existing preferences
|
||||
threshold: Similarity threshold
|
||||
|
||||
Returns:
|
||||
Similar preference if found, None otherwise
|
||||
"""
|
||||
target_name = target_preference.tag_name.lower().strip()
|
||||
|
||||
for existing_pref in existing_preferences:
|
||||
existing_name = existing_pref.tag_name.lower().strip()
|
||||
|
||||
# Simple similarity check based on common words
|
||||
similarity = self._calculate_text_similarity(target_name, existing_name)
|
||||
|
||||
if similarity >= threshold:
|
||||
return existing_pref
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Calculate simple text similarity based on common words.
|
||||
|
||||
Args:
|
||||
text1: First text
|
||||
text2: Second text
|
||||
|
||||
Returns:
|
||||
Similarity score between 0.0 and 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Simple word-based similarity
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
if not words1 or not words2:
|
||||
return 0.0
|
||||
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
return len(intersection) / len(union) if union else 0.0
|
||||
|
||||
def _merge_preferences(
    self,
    existing_pref: PreferenceTag,
    new_pref: PreferenceTag,
    current_time: datetime
) -> PreferenceTag:
    """Merge two similar preferences into a new preference tag.

    Neither input is modified. The merged tag keeps the existing tag's name
    and creation time, combines evidence and conversation references
    without duplicates, and averages the confidence scores weighted by the
    number of supporting-evidence entries on each side.

    Args:
        existing_pref: Existing preference
        new_pref: New preference to merge
        current_time: Timestamp stored as ``updated_at`` on the result

    Returns:
        Merged preference tag
    """
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # merged lists are deterministic (a set would scramble the order).
    combined_evidence = list(dict.fromkeys(
        existing_pref.supporting_evidence + new_pref.supporting_evidence
    ))
    combined_refs = list(dict.fromkeys(
        existing_pref.conversation_references + new_pref.conversation_references
    ))

    # Confidence: average weighted by how much evidence backs each side.
    existing_weight = len(existing_pref.supporting_evidence)
    new_weight = len(new_pref.supporting_evidence)
    total_weight = existing_weight + new_weight

    if total_weight > 0:
        new_confidence = (
            existing_pref.confidence_score * existing_weight
            + new_pref.confidence_score * new_weight
        ) / total_weight
    else:
        # No evidence on either side: fall back to the higher score.
        new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)

    # Ensure confidence doesn't exceed 1.0
    new_confidence = min(new_confidence, 1.0)

    # Append the new context unless it is already contained in the existing
    # one. A missing or empty existing context is replaced outright, so the
    # result never starts with a dangling "; " separator.
    combined_context = existing_pref.context_details or ""
    new_context = new_pref.context_details
    if new_context and new_context not in combined_context:
        if combined_context:
            combined_context = f"{combined_context}; {new_context}"
        else:
            combined_context = new_context

    return PreferenceTag(
        tag_name=existing_pref.tag_name,  # Keep original name
        confidence_score=new_confidence,
        supporting_evidence=combined_evidence,
        context_details=combined_context,
        category=existing_pref.category or new_pref.category,
        conversation_references=combined_refs,
        created_at=existing_pref.created_at,
        updated_at=current_time
    )
|
||||
97
api/app/core/memory/analytics/implicit_memory/data_source.py
Normal file
97
api/app/core/memory/analytics/implicit_memory/data_source.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Memory Data Source
|
||||
|
||||
Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryDataSource:
    """Retrieves memory summaries through ``MemorySummaryRepository``.

    Raw summary dicts returned by the repository are converted into
    ``UserMemorySummary`` objects; entries that cannot be converted are
    skipped.
    """

    def __init__(
        self,
        db: Session,
        neo4j_connector: Optional[Neo4jConnector] = None
    ):
        """Store the session and set up the summary repository.

        Args:
            db: Database session (stored on the instance; not used by the
                methods of this class).
            neo4j_connector: Connector to reuse. A new ``Neo4jConnector()``
                is created when omitted.
        """
        self.db = db
        self.neo4j_connector = neo4j_connector or Neo4jConnector()
        self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)

    def _parse_timestamp(self, timestamp: Any) -> datetime:
        """Parse timestamp from various formats.

        Args:
            timestamp: An ISO-8601 string (a ``Z`` suffix is accepted), an
                object exposing ``to_native()``, ``None``, or a value that
                is already a datetime.

        Returns:
            The parsed datetime. ``None`` yields the current local time.
            Strings carrying an offset produce timezone-aware datetimes.
        """
        if timestamp is None:
            return datetime.now()
        if isinstance(timestamp, str):
            # fromisoformat does not accept a trailing "Z" on older Python
            # versions, so it is rewritten as an explicit UTC offset.
            return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))

        # Driver temporal values (presumably neo4j.time.DateTime - confirm
        # against the repository) expose to_native() to obtain a stdlib
        # datetime; convert them instead of leaking the driver type.
        to_native = getattr(timestamp, "to_native", None)
        if callable(to_native):
            return to_native()
        return timestamp

    def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
        """Convert a summary dict to a ``UserMemorySummary``.

        Args:
            summary_dict: Raw summary record. The text is read from
                ``content`` (falling back to ``summary``), the identifier
                from ``id`` (falling back to ``uuid``) and the timestamp
                from ``created_at``.
            user_id: User the summary is attributed to.

        Returns:
            The converted summary, or None when the record has no text or
            conversion fails (the failure is logged as a warning).
        """
        try:
            # "or" chaining also falls back when the primary key is present
            # but holds None or an empty value; dict.get(key, default) only
            # falls back when the key is missing entirely.
            content = summary_dict.get("content") or summary_dict.get("summary") or ""
            if not content or not content.strip():
                return None

            summary_id = summary_dict.get("id") or summary_dict.get("uuid") or ""

            return UserMemorySummary(
                summary_id=summary_id,
                user_id=user_id,
                user_content=content,
                timestamp=self._parse_timestamp(summary_dict.get("created_at")),
                confidence_score=1.0,
                summary_type="memory_summary"
            )
        except Exception as e:
            logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
            return None

    async def get_user_summaries(
        self,
        user_id: str,
        time_range: Optional[TimeRange] = None,
        limit: int = 1000
    ) -> List[UserMemorySummary]:
        """Retrieve user memory summaries through the summary repository.

        Args:
            user_id: Target user ID; passed to the repository as ``group_id``
            time_range: Optional time range filter
            limit: Maximum number of summaries requested from the repository

        Returns:
            List of user memory summaries. Records that cannot be converted
            are left out, so the list may be shorter than the query result.

        Raises:
            Exception: Any error raised while querying is logged and
                re-raised unchanged.
        """
        try:
            start_date = time_range.start_date if time_range else None
            end_date = time_range.end_date if time_range else None

            summary_dicts = await self.memory_summary_repo.find_by_group_id(
                group_id=user_id,
                limit=limit,
                start_date=start_date,
                end_date=end_date
            )

            converted = (
                self._dict_to_user_summary(summary_dict, user_id)
                for summary_dict in summary_dicts
            )
            summaries = [summary for summary in converted if summary]

            logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
            return summaries

        except Exception as e:
            logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
            raise
|
||||
|
||||
226
api/app/core/memory/analytics/implicit_memory/habit_detector.py
Normal file
226
api/app/core/memory/analytics/implicit_memory/habit_detector.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""Habit Detector for Implicit Memory System
|
||||
|
||||
This module implements the HabitDetector class that specializes in identifying
|
||||
and ranking behavioral habits from user memory summaries. It provides advanced
|
||||
habit analysis with confidence scoring, recency weighting, and current vs past
|
||||
habit distinction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
|
||||
HabitAnalyzer,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitDetector:
    """Detects and ranks behavioral habits from user memory summaries."""

    def __init__(
        self,
        db: Session,
        llm_model_id: Optional[str] = None
    ):
        """Initialize the habit detector.

        Args:
            db: Database session
            llm_model_id: Optional LLM model ID to use for analysis
        """
        self.db = db
        self.llm_model_id = llm_model_id
        self.habit_analyzer = HabitAnalyzer(db, llm_model_id)

    async def detect_habits(
        self,
        user_id: str,
        user_summaries: List[UserMemorySummary],
        existing_habits: Optional[List[BehaviorHabit]] = None
    ) -> List[BehaviorHabit]:
        """Detect behavioral habits from user summaries.

        The habits returned by the analyzer are ranked and then split into
        current and past habits, current ones first.

        Args:
            user_id: Target user ID
            user_summaries: List of user-specific memory summaries
            existing_habits: Optional existing habits for consolidation

        Returns:
            List of detected and ranked behavioral habits. When no summaries
            are given, ``existing_habits`` (or an empty list) is returned
            unchanged.

        Raises:
            LLMClientException: If habit analysis fails. Any other exception
                raised during detection is wrapped in this type as well.
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for user {user_id}")
            return existing_habits or []

        logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")

        try:
            # Use the habit analyzer to extract habits
            detected_habits = await self.habit_analyzer.analyze_habits(
                user_id=user_id,
                user_summaries=user_summaries,
                existing_habits=existing_habits
            )

            # Apply advanced ranking and filtering
            ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)

            # Distinguish current vs past habits
            categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)

            logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
            return categorized_habits

        except LLMClientException:
            logger.error(f"Habit detection failed for user {user_id}")
            raise
        except Exception as e:
            logger.error(f"Habit detection failed for user {user_id}: {e}")
            raise LLMClientException(f"Habit detection failed: {e}") from e

    @staticmethod
    def _now_like(reference: datetime) -> datetime:
        """Return the current time with the same tzinfo as ``reference``.

        Naive references get a naive local "now" (identical to
        ``datetime.now()``); aware references get an aware "now", so the
        two can be subtracted or compared without a TypeError.
        """
        return datetime.now(tz=reference.tzinfo)

    @staticmethod
    def _recency_score(days_since_last: int) -> float:
        """Map the age of the last observation to a stepwise recency score."""
        recency_steps = (
            (7, 1.0),    # Very recent
            (30, 0.8),   # Recent
            (90, 0.5),   # Somewhat recent
            (180, 0.3),  # Old
        )
        for max_days, score in recency_steps:
            if days_since_last <= max_days:
                return score
        return 0.1  # Very old

    def rank_habits_by_confidence_and_recency(
        self,
        habits: List[BehaviorHabit],
        confidence_weight: float = 0.6,
        recency_weight: float = 0.4
    ) -> List[BehaviorHabit]:
        """Rank habits by confidence level and recency.

        The score of a habit is the weighted sum of its confidence (scaled
        from 0-100 to 0-1) and a stepwise recency score, plus small bonuses
        for frequency pattern, available examples and being marked current,
        capped at 1.0. Habits with equal scores keep their input order.

        Args:
            habits: List of habits to rank
            confidence_weight: Weight for confidence score (0.0-1.0)
            recency_weight: Weight for recency score (0.0-1.0)

        Returns:
            List of habits ranked by combined score, highest first
        """
        if not habits:
            return []

        logger.info(f"Ranking {len(habits)} habits by confidence and recency")

        # Frequency pattern bonus; patterns not listed here get no bonus.
        frequency_bonuses = {
            FrequencyPattern.DAILY: 0.2,
            FrequencyPattern.WEEKLY: 0.15,
            FrequencyPattern.MONTHLY: 0.1,
            FrequencyPattern.SEASONAL: 0.05,
            FrequencyPattern.OCCASIONAL: 0.0,
            FrequencyPattern.EVENT_TRIGGERED: 0.05
        }

        def calculate_ranking_score(habit: BehaviorHabit) -> float:
            """Calculate combined ranking score for a habit."""
            # Confidence score (0.0-1.0) - convert from 0-100 scale
            confidence_score = habit.confidence_level / 100.0

            # Recency score (0.0-1.0), stepwise by days since last seen
            elapsed = self._now_like(habit.last_observed) - habit.last_observed
            recency_score = self._recency_score(elapsed.days)

            frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)

            # Evidence quality bonus, max 0.1.
            # NOTE(review): with a divisor of 10 the cap is already reached
            # by a single example - confirm this is intended.
            evidence_bonus = min(len(habit.specific_examples) / 10.0, 0.1)

            # Current habit bonus
            current_bonus = 0.1 if habit.is_current else 0.0

            base_score = (confidence_score * confidence_weight +
                          recency_score * recency_weight)
            final_score = base_score + frequency_bonus + evidence_bonus + current_bonus

            return min(final_score, 1.0)  # Cap at 1.0

        # Score every habit once and reuse the value for sorting and
        # logging. list.sort is stable, also with reverse=True.
        scored = [(calculate_ranking_score(habit), habit) for habit in habits]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        ranked_habits = [habit for _, habit in scored]

        logger.info(f"Ranked habits with scores: {[score for score, _ in scored[:5]]}")

        return ranked_habits

    def distinguish_current_vs_past_habits(
        self,
        habits: List[BehaviorHabit],
        current_threshold_days: int = 30
    ) -> List[BehaviorHabit]:
        """Distinguish between current and past habits based on recency.

        A habit counts as current when its ``last_observed`` lies within the
        last ``current_threshold_days`` days. Each habit is rebuilt as a new
        ``BehaviorHabit`` with ``is_current`` set accordingly; the input
        objects are not modified.

        Args:
            habits: List of habits to categorize
            current_threshold_days: Days threshold for considering a habit current

        Returns:
            List of habits with updated is_current status: current habits
            first, then past habits, each group in input order
        """
        if not habits:
            return []

        threshold = timedelta(days=current_threshold_days)
        current_habits = []
        past_habits = []

        for habit in habits:
            # The cutoff is built per habit so that it matches the
            # timezone-awareness of that habit's last_observed value.
            cutoff_date = self._now_like(habit.last_observed) - threshold
            is_current = habit.last_observed >= cutoff_date

            updated_habit = BehaviorHabit(
                habit_description=habit.habit_description,
                frequency_pattern=habit.frequency_pattern,
                time_context=habit.time_context,
                confidence_level=habit.confidence_level,
                specific_examples=habit.specific_examples,
                first_observed=habit.first_observed,
                last_observed=habit.last_observed,
                is_current=is_current
            )
            if is_current:
                current_habits.append(updated_habit)
            else:
                past_habits.append(updated_habit)

        # Return current habits first, then past habits
        categorized_habits = current_habits + past_habits

        logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")

        return categorized_habits
|
||||
321
api/app/core/memory/analytics/implicit_memory/llm_client.py
Normal file
321
api/app/core/memory/analytics/implicit_memory/llm_client.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""LLM Client Wrapper for Implicit Memory Analysis
|
||||
|
||||
This module provides a specialized LLM client wrapper that integrates with the
|
||||
MemoryClientFactory to perform implicit memory analysis tasks including preference
|
||||
extraction, personality dimension analysis, interest categorization, and habit detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.prompts import (
|
||||
get_dimension_analysis_prompt,
|
||||
get_habit_analysis_prompt,
|
||||
get_interest_analysis_prompt,
|
||||
get_preference_analysis_prompt,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.schemas.implicit_memory_schema import UserMemorySummary
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Response Models for LLM Analysis
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
    """Structured-output schema for the preference analysis LLM call.

    Passed as ``response_model`` to ``response_structured`` by
    ``ImplicitMemoryLLMClient.analyze_preferences``, which returns its
    ``model_dump()`` to the caller.
    """
    # One free-form dict per extracted preference; empty list by default.
    preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
    """Structured-output schema for the dimension analysis LLM call.

    Passed as ``response_model`` to ``response_structured`` by
    ``ImplicitMemoryLLMClient.analyze_dimensions``, which returns its
    ``model_dump()`` to the caller.
    """
    # Maps a dimension name to a free-form dict of details; empty by default.
    dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
    """Structured-output schema for the interest analysis LLM call.

    Passed as ``response_model`` to ``response_structured`` by
    ``ImplicitMemoryLLMClient.analyze_interests``, which returns its
    ``model_dump()`` to the caller.
    """
    # Maps an interest area name to a free-form dict of details; empty by default.
    interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
    """Structured-output schema for the habit analysis LLM call.

    Passed as ``response_model`` to ``response_structured`` by
    ``ImplicitMemoryLLMClient.analyze_habits``, which returns its
    ``model_dump()`` to the caller.
    """
    # One free-form dict per identified habit; empty list by default.
    habits: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ImplicitMemoryLLMClient:
    """LLM client wrapper for implicit memory analysis.

    This class provides a high-level interface for performing LLM-based analysis
    of user memory summaries to extract preferences, personality dimensions,
    interests, and behavioral habits. All four analyses share one pipeline:
    format the summaries, render the matching prompt, request a structured
    response and return it as a plain dictionary.
    """

    def __init__(self, db: Session, default_model_id: Optional[str] = None):
        """Initialize the LLM client wrapper.

        Args:
            db: Database session for accessing model configurations
            default_model_id: Default LLM model ID to use if none specified
        """
        self.db = db
        self.default_model_id = default_model_id
        self._client_factory = MemoryClientFactory(db)

        logger.info("ImplicitMemoryLLMClient initialized")

    def _get_llm_client(self, model_id: Optional[str] = None):
        """Get LLM client instance.

        Args:
            model_id: LLM model ID to use, defaults to default_model_id

        Returns:
            LLM client instance

        Raises:
            ValueError: If no model ID is provided and no default is set
            LLMClientException: If client creation fails
        """
        effective_model_id = model_id or self.default_model_id
        if not effective_model_id:
            raise ValueError("No LLM model ID provided and no default model ID set")

        try:
            client = self._client_factory.get_llm_client(effective_model_id)
            logger.debug(f"Created LLM client for model: {effective_model_id}")
            return client
        except Exception as e:
            logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
            raise LLMClientException(f"Failed to create LLM client: {e}") from e

    def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
        """Prepare user memory summaries for LLM analysis.

        Args:
            user_summaries: List of user memory summaries

        Returns:
            List of formatted summary dictionaries with the keys
            ``summary_id``, ``user_content``, ``timestamp`` (ISO-8601
            string), ``summary_type`` and ``confidence_score``
        """
        formatted_summaries = [
            {
                'summary_id': summary.summary_id,
                'user_content': summary.user_content,
                'timestamp': summary.timestamp.isoformat(),
                'summary_type': summary.summary_type,
                'confidence_score': summary.confidence_score
            }
            for summary in user_summaries
        ]

        logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
        return formatted_summaries

    async def _run_structured_analysis(
        self,
        prompt_builder,
        response_model,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str]
    ) -> Dict[str, Any]:
        """Run the pipeline shared by all analyze_* methods.

        Args:
            prompt_builder: Callable taking ``(formatted_summaries, user_id)``
                and returning the prompt text
            response_model: Pydantic model class describing the expected
                structured response
            user_summaries: Summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            The structured response converted with ``model_dump()``

        Raises:
            Exception: Errors are not handled here; each public method
                wraps them in ``LLMClientException`` with its own message.
        """
        # Prepare summaries and get prompt
        formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
        prompt = prompt_builder(formatted_summaries, user_id)

        # Get LLM client and perform analysis
        llm_client = self._get_llm_client(model_id)
        messages = [{"role": "user", "content": prompt}]

        # Use structured output for reliable parsing
        response = await llm_client.response_structured(
            messages=messages,
            response_model=response_model
        )
        return response.model_dump()

    async def analyze_preferences(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user preferences from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing extracted preferences under the
            ``preferences`` key (an empty list when no summaries are given)

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for preference analysis of user {user_id}")
            return {"preferences": []}

        if not user_id:
            raise ValueError("User ID is required for preference analysis")

        try:
            result = await self._run_structured_analysis(
                get_preference_analysis_prompt,
                PreferenceAnalysisResponse,
                user_summaries,
                user_id,
                model_id
            )
            logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
            return result

        except Exception as e:
            logger.error(f"Preference analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Preference analysis failed: {e}") from e

    async def analyze_dimensions(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user personality dimensions from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing dimension scores and analysis under the
            ``dimensions`` key (an empty dict when no summaries are given)

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
            return {"dimensions": {}}

        if not user_id:
            raise ValueError("User ID is required for dimension analysis")

        try:
            result = await self._run_structured_analysis(
                get_dimension_analysis_prompt,
                DimensionAnalysisResponse,
                user_summaries,
                user_id,
                model_id
            )
            dimensions = result.get('dimensions', {})
            logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
            return result

        except Exception as e:
            logger.error(f"Dimension analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Dimension analysis failed: {e}") from e

    async def analyze_interests(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user interest distribution from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing interest area distribution under the
            ``interest_distribution`` key (an empty dict when no summaries
            are given)

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for interest analysis of user {user_id}")
            return {"interest_distribution": {}}

        if not user_id:
            raise ValueError("User ID is required for interest analysis")

        try:
            result = await self._run_structured_analysis(
                get_interest_analysis_prompt,
                InterestAnalysisResponse,
                user_summaries,
                user_id,
                model_id
            )
            interest_dist = result.get('interest_distribution', {})
            logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
            return result

        except Exception as e:
            logger.error(f"Interest analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Interest analysis failed: {e}") from e

    async def analyze_habits(
        self,
        user_summaries: List[UserMemorySummary],
        user_id: str,
        model_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Analyze user behavioral habits from memory summaries.

        Args:
            user_summaries: List of user memory summaries to analyze
            user_id: Target user ID for analysis
            model_id: Optional LLM model ID to use

        Returns:
            Dictionary containing identified behavioral habits under the
            ``habits`` key (an empty list when no summaries are given)

        Raises:
            LLMClientException: If LLM analysis fails
            ValueError: If input validation fails
        """
        if not user_summaries:
            logger.warning(f"No summaries provided for habit analysis of user {user_id}")
            return {"habits": []}

        if not user_id:
            raise ValueError("User ID is required for habit analysis")

        try:
            result = await self._run_structured_analysis(
                get_habit_analysis_prompt,
                HabitAnalysisResponse,
                user_summaries,
                user_id,
                model_id
            )
            logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
            return result

        except Exception as e:
            logger.error(f"Habit analysis failed for user {user_id}: {e}")
            raise LLMClientException(f"Habit analysis failed: {e}") from e
|
||||
69
api/app/core/memory/analytics/implicit_memory/prompts.py
Normal file
69
api/app/core/memory/analytics/implicit_memory/prompts.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""LLM Prompt Templates for Implicit Memory Analysis
|
||||
|
||||
This module contains prompt rendering functions for analyzing user memory summaries
|
||||
to extract preferences, personality dimensions, interests, and behavioral habits.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
# Jinja2 environment: templates are loaded from the "prompts" directory
# located next to this module.
current_dir = os.path.abspath(os.path.dirname(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(searchpath=prompt_dir))
|
||||
|
||||
|
||||
def _render_template(template_name: str, **kwargs) -> str:
    """Render the named template from ``prompt_env``.

    Args:
        template_name: File name of the template to load.
        **kwargs: Variables made available to the template.

    Returns:
        The rendered template text.
    """
    return prompt_env.get_template(template_name).render(**kwargs)
|
||||
|
||||
|
||||
def get_preference_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the preference analysis prompt.

    Args:
        memory_summaries: Formatted summary dicts exposed to the template
            as ``memory_summaries``.
        user_id: Target user ID exposed to the template as ``user_id``.

    Returns:
        The rendered ``preference_analysis.jinja2`` template.
    """
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("preference_analysis.jinja2", **context)
|
||||
|
||||
|
||||
def get_dimension_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the dimension analysis prompt.

    Args:
        memory_summaries: Formatted summary dicts exposed to the template
            as ``memory_summaries``.
        user_id: Target user ID exposed to the template as ``user_id``.

    Returns:
        The rendered ``dimension_analysis.jinja2`` template.
    """
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("dimension_analysis.jinja2", **context)
|
||||
|
||||
|
||||
def get_interest_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the interest analysis prompt.

    Args:
        memory_summaries: Formatted summary dicts exposed to the template
            as ``memory_summaries``.
        user_id: Target user ID exposed to the template as ``user_id``.

    Returns:
        The rendered ``interest_analysis.jinja2`` template.
    """
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("interest_analysis.jinja2", **context)
|
||||
|
||||
|
||||
def get_habit_analysis_prompt(
    memory_summaries: List[Dict[str, Any]],
    user_id: str
) -> str:
    """Render the habit analysis prompt.

    Args:
        memory_summaries: Formatted summary dicts exposed to the template
            as ``memory_summaries``.
        user_id: Target user ID exposed to the template as ``user_id``.

    Returns:
        The rendered ``habit_analysis.jinja2`` template.
    """
    context = {"memory_summaries": memory_summaries, "user_id": user_id}
    return _render_template("habit_analysis.jinja2", **context)
|
||||
@@ -0,0 +1,41 @@
|
||||
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Dimensions to Analyze
|
||||
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
|
||||
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
|
||||
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
|
||||
4. **Literature** (0-100%): Reading habits, writing style, literary references
|
||||
|
||||
## Instructions
|
||||
1. Analyze the user's content for each dimension
|
||||
2. Calculate percentage scores (0-100%)
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 0-100},
|
||||
"aesthetic": {"percentage": 0-100},
|
||||
"technology": {"percentage": 0-100},
|
||||
"literature": {"percentage": 0-100}
|
||||
}
|
||||
}
|
||||
|
||||
## Example
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 75},
|
||||
"aesthetic": {"percentage": 45},
|
||||
"technology": {"percentage": 60},
|
||||
"literature": {"percentage": 30}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
You are an expert at identifying behavioral patterns and habits from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
|
||||
2. Focus on specific, concrete habits with temporal patterns
|
||||
3. For each habit, provide:
|
||||
- habit_description: Clear, specific description
|
||||
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
|
||||
- time_context: When it typically happens
|
||||
- confidence_level: "high", "medium", "low"
|
||||
- supporting_summaries: References to evidence
|
||||
- specific_examples: Concrete examples from summaries
|
||||
- is_current: true if current habit, false if past habit
|
||||
4. Only include habits with medium or high confidence
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "string",
|
||||
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
|
||||
"time_context": "string",
|
||||
"confidence_level": "high|medium|low",
|
||||
"supporting_summaries": ["id1", "id2"],
|
||||
"specific_examples": ["example1", "example2"],
|
||||
"is_current": true|false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "drinks coffee every morning",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "morning routine",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["needs coffee to start the day"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "每天早上喝咖啡",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "早晨日常",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["需要咖啡来开始一天"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
You are an expert at analyzing user interests from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Interest Categories
|
||||
1. **Tech**: Programming, technology, software tools, hardware
|
||||
2. **Lifestyle**: Daily routines, health, hobbies, social activities
|
||||
3. **Music**: Music preferences, instruments, concerts
|
||||
4. **Art**: Visual arts, creative projects, design, aesthetics
|
||||
|
||||
## Instructions
|
||||
1. Categorize the user's interests into the four areas
|
||||
2. Calculate percentage distribution (must total 100%)
|
||||
3. Provide specific evidence for each interest area
|
||||
4. Use "increasing", "decreasing", or "stable" for trending direction; use null when no trend can be determined
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
You are an expert at analyzing user memory summaries to identify implicit preferences.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Focus ONLY on the specified user's preferences
|
||||
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
|
||||
3. DO NOT use long phrases - use short nouns or noun phrases
|
||||
4. Only include preferences with confidence_score >= 0.3
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"preferences": [
|
||||
{
|
||||
"tag_name": "short tag",
|
||||
"confidence_score": 0.0-1.0,
|
||||
"supporting_evidence": ["evidence1", "evidence2"],
|
||||
"context_details": "brief context",
|
||||
"category": "category or null"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
|
||||
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
|
||||
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
|
||||
]
|
||||
}
|
||||
@@ -1,327 +0,0 @@
|
||||
"""
|
||||
This module provides the MemoryInsight class for analyzing user memory data.
|
||||
|
||||
MemoryInsight 是一个工具类,提供基础的数据获取和分析功能:
|
||||
- get_domain_distribution(): 获取记忆领域分布
|
||||
- get_active_periods(): 获取活跃时段
|
||||
- get_social_connections(): 获取社交关联
|
||||
|
||||
业务逻辑(如生成洞察报告)应该在服务层(user_memory_service.py)中实现。
|
||||
|
||||
This script can be executed directly to test the memory insight generation for a test user.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter
|
||||
from datetime import datetime
|
||||
|
||||
# Make the parent of this file's directory importable so the absolute imports
# below resolve when the module is executed directly as a script.
_this_dir = os.path.dirname(__file__)
src_path = os.path.abspath(os.path.join(_this_dir, os.pardir))
if src_path not in sys.path:
    sys.path.insert(0, src_path)
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Fallback identifiers, overridable through environment variables
# (previously provided by definitions.py).
DEFAULT_LLM_ID = os.environ.get("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.environ.get("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# Pydantic models used for structured LLM output.
class TagClassification(BaseModel):
    """Structured LLM answer that assigns one tag to a single domain."""

    # Required field; the examples enumerate the predefined domain names.
    domain: str = Field(
        ...,
        description="The domain the tag belongs to, chosen from the predefined list.",
        examples=[
            "教育", "学习", "工作", "旅行", "家庭",
            "运动", "社交", "娱乐", "健康", "其他",
        ],
    )
|
||||
|
||||
class InsightReport(BaseModel):
    """Structured LLM answer carrying the final insight report text."""

    # Required field holding the full report body.
    report: str = Field(
        ...,
        description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
    )
|
||||
|
||||
|
||||
class MemoryInsight:
|
||||
"""
|
||||
Provides insights into user memories by analyzing various aspects of their data.
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
"""关闭数据库连接。"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
async def get_domain_distribution(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculates the distribution of memory domains based on hot tags.
|
||||
"""
|
||||
hot_tags = await get_hot_memory_tags(self.user_id)
|
||||
if not hot_tags:
|
||||
return {}
|
||||
|
||||
domain_counts = Counter()
|
||||
for tag, _ in hot_tags:
|
||||
prompt = f"""请将以下标签归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
|
||||
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
|
||||
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
|
||||
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
|
||||
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
|
||||
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
|
||||
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
|
||||
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
|
||||
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
|
||||
- 其他:确实无法归入以上任何类别的内容
|
||||
|
||||
标签: {tag}
|
||||
|
||||
分析步骤:
|
||||
1. 仔细理解标签的核心含义和使用场景
|
||||
2. 对比各个领域的关键词,找到最匹配的领域
|
||||
3. 特别注意:
|
||||
- 历史、科学、文化等知识性内容应归类为"学习"
|
||||
- 学校、课程、考试等正式教育场景应归类为"教育"
|
||||
- 只有在标签完全不属于上述9个具体领域时,才选择"其他"
|
||||
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景,优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
# 直接调用并等待结果
|
||||
classification = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=TagClassification,
|
||||
)
|
||||
if classification and hasattr(classification, 'domain') and classification.domain:
|
||||
domain_counts[classification.domain] += 1
|
||||
|
||||
total_tags = sum(domain_counts.values())
|
||||
if total_tags == 0:
|
||||
return {}
|
||||
|
||||
domain_distribution = {
|
||||
domain: count / total_tags for domain, count in domain_counts.items()
|
||||
}
|
||||
return dict(
|
||||
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
|
||||
)
|
||||
|
||||
async def get_active_periods(self) -> list[int]:
|
||||
"""
|
||||
Identifies the top 2 most active months for the user.
|
||||
Only returns months if there is valid and diverse time data.
|
||||
|
||||
This method checks if the time data represents real user memory timestamps
|
||||
rather than auto-generated system timestamps by verifying:
|
||||
1. Time data exists and is parseable
|
||||
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
|
||||
month_counts = Counter()
|
||||
valid_dates_count = 0
|
||||
for record in records:
|
||||
creation_time_str = record.get("creation_time")
|
||||
if not creation_time_str:
|
||||
continue
|
||||
try:
|
||||
# 尝试解析时间字符串
|
||||
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
|
||||
month_counts[dt_object.month] += 1
|
||||
valid_dates_count += 1
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
# 如果解析失败,跳过这条记录
|
||||
continue
|
||||
|
||||
# 如果没有有效的时间数据,返回空列表
|
||||
if not month_counts or valid_dates_count == 0:
|
||||
return []
|
||||
|
||||
# 检查时间分布是否过于集中(可能是批量导入的数据)
|
||||
# 如果超过80%的数据集中在1-2个月,认为这是系统时间戳而非真实时间
|
||||
unique_months = len(month_counts)
|
||||
if unique_months <= 2:
|
||||
# 只有1-2个月有数据,很可能是批量导入
|
||||
most_common_count = month_counts.most_common(1)[0][1]
|
||||
if most_common_count / valid_dates_count > 0.8:
|
||||
# 超过80%集中在一个月,认为是系统时间戳
|
||||
return []
|
||||
|
||||
# 如果时间分布较为分散(3个月以上),认为是真实时间数据
|
||||
if unique_months >= 3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 2个月的情况,检查是否分布均匀
|
||||
if unique_months == 2:
|
||||
counts = list(month_counts.values())
|
||||
# 如果两个月的数据量相差不大(比例在0.3-3之间),认为是真实数据
|
||||
ratio = min(counts) / max(counts)
|
||||
if ratio > 0.3:
|
||||
most_common_months = month_counts.most_common(2)
|
||||
return [month for month, _ in most_common_months]
|
||||
|
||||
# 其他情况返回空列表
|
||||
return []
|
||||
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""
|
||||
Finds the user with whom the most memories are shared.
|
||||
使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。
|
||||
"""
|
||||
# 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆
|
||||
query = f"""
|
||||
MATCH (c1:Chunk {{group_id: '{self.user_id}'}})
|
||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||
WHERE common_statements > 0
|
||||
RETURN other_user_id, common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query)
|
||||
if not records or not records[0].get("other_user_id"):
|
||||
return None
|
||||
|
||||
most_connected_user = records[0]["other_user_id"]
|
||||
common_memories_count = records[0]["common_statements"]
|
||||
|
||||
# 使用 Chunk 的时间范围
|
||||
time_range_query = f"""
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}']
|
||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(time_range_query)
|
||||
start_year, end_year = "N/A", "N/A"
|
||||
if time_records and time_records[0]["start_time"]:
|
||||
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
|
||||
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
|
||||
|
||||
return {
|
||||
"user_id": most_connected_user,
|
||||
"common_memories_count": common_memories_count,
|
||||
"time_range": f"{start_year}-{end_year}",
|
||||
}
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
Closes the database connection.
|
||||
"""
|
||||
await self.neo4j_connector.close()
|
||||
|
||||
|
||||
async def main():
    """Generate the memory insight report for the default test user.

    Prints the report, then stores it under the ``memory_insight`` key of
    ``User-Dashboard.json`` in the configured memory output directory. Failures
    in either step are printed rather than raised.
    """
    # The test user id comes from the environment-backed default.
    test_user_id = DEFAULT_GROUP_ID
    print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")

    try:
        # Report generation lives in the service layer.
        from app.services.user_memory_service import analytics_memory_insight_report

        outcome = await analytics_memory_insight_report(end_user_id=test_user_id)
        report = outcome.get("report", "")

        print("--- 记忆洞察报告 ---")
        print(report)
        print("---------------------")
    except Exception as exc:
        print(f"生成报告时出错: {exc}")
        return

    # Merge the result into the shared User-Dashboard.json file.
    try:
        from app.core.config import settings

        settings.ensure_memory_output_dir()
        target_dir = settings.MEMORY_OUTPUT_DIR
        try:
            os.makedirs(target_dir, exist_ok=True)
        except Exception:
            # Best effort: a real problem surfaces when the file is opened below.
            pass

        dashboard_path = os.path.join(target_dir, "User-Dashboard.json")
        dashboard = {}
        if os.path.exists(dashboard_path):
            with open(dashboard_path, encoding="utf-8") as src:
                dashboard = json.load(src)

        dashboard["memory_insight"] = {"group_id": test_user_id, "report": report}
        with open(dashboard_path, "w", encoding="utf-8") as dst:
            json.dump(dashboard, dst, ensure_ascii=False, indent=2)
        print(f"已写入 {dashboard_path} -> memory_insight")
    except Exception as exc:
        print(f"写入 User-Dashboard.json 失败: {exc}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # On Windows (Python 3.8+) switch to the selector event loop policy before
    # running the async entry point.
    on_windows = sys.platform.startswith('win')
    if on_windows and sys.version_info >= (3, 8):
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
    asyncio.run(main())
|
||||
@@ -1,157 +0,0 @@
|
||||
"""
|
||||
Generate a concise "关于我" style user summary using data from Neo4j
|
||||
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
|
||||
|
||||
Usage:
|
||||
python -m analytics.user_summary  # note: --user_id is not parsed; the script calls generate_user_summary() with no arguments
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple
|
||||
|
||||
# Put the directory two levels above this file, and its "src" subdirectory, at
# the front of sys.path so absolute imports resolve for both script and module
# execution. Any failure here is deliberately ignored.
try:
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
    src_path = os.path.join(project_root, 'src')
    for _candidate in (src_path, project_root):
        if _candidate not in sys.path:
            sys.path.insert(0, _candidate)
except Exception:
    pass
|
||||
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
#TODO: Fix this
|
||||
|
||||
# Fallback identifiers, overridable through environment variables
# (previously provided by definitions.py).
DEFAULT_LLM_ID = os.environ.get("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.environ.get("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
|
||||
@dataclass
|
||||
class StatementRecord:
|
||||
statement: str
|
||||
created_at: str | None
|
||||
|
||||
|
||||
class UserSummary:
|
||||
"""Builds a textual user summary for a given user/group id."""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
async def close(self):
|
||||
await self.connector.close()
|
||||
|
||||
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service
|
||||
"""Fetch recent statements authored by the user/group for context."""
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||
"ORDER BY created_at DESC LIMIT $limit"
|
||||
)
|
||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||
records: List[StatementRecord] = []
|
||||
for r in rows:
|
||||
try:
|
||||
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
|
||||
except Exception:
|
||||
continue
|
||||
return records
|
||||
|
||||
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
|
||||
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
|
||||
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
|
||||
return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service
|
||||
|
||||
|
||||
async def generate_user_summary(user_id: str | None = None) -> str:  # TODO useless
    """Convenience wrapper that returns the user summary text.

    Delegates to the service-layer ``analytics_user_summary`` function.

    Args:
        user_id: Optional user id, forwarded unchanged to the service layer.

    Returns:
        The value of the ``"summary"`` key in the service result, or an empty
        string when that key is absent.
    """
    # Imported lazily inside the function, as in the original.
    from app.services.user_memory_service import analytics_user_summary

    return (await analytics_user_summary(user_id)).get("summary", "")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print("开始生成用户摘要…")
    try:
        # No user id is passed; generate_user_summary() runs with its default.
        summary = asyncio.run(generate_user_summary())
        print("\n— 用户摘要 —\n")
        print(summary)
    except Exception as exc:
        print(f"生成摘要失败: {exc}")
        print("请检查: 1) Neo4j 是否可用;2) config.json 与 .env 的 LLM/Neo4j 配置是否正确;3) 数据是否包含该用户的内容。")
    else:
        # Merge the result into the shared User-Dashboard.json file.
        try:
            from app.core.config import settings

            settings.ensure_memory_output_dir()
            target_dir = settings.MEMORY_OUTPUT_DIR
            try:
                os.makedirs(target_dir, exist_ok=True)
            except Exception:
                # Best effort: a real problem surfaces when the file is opened below.
                pass

            dashboard_path = os.path.join(target_dir, "User-Dashboard.json")
            dashboard = {}
            if os.path.exists(dashboard_path):
                with open(dashboard_path, encoding="utf-8") as src:
                    dashboard = json.load(src)

            dashboard["user_summary"] = {"group_id": DEFAULT_GROUP_ID, "summary": summary}
            with open(dashboard_path, "w", encoding="utf-8") as dst:
                json.dump(dashboard, dst, ensure_ascii=False, indent=2)
            print(f"已写入 {dashboard_path} -> user_summary")
        except Exception as exc:
            print(f"写入 User-Dashboard.json 失败: {exc}")
|
||||
@@ -37,12 +37,20 @@ def parse_historical_datetime(v):
|
||||
此函数手动解析 ISO 8601 格式的日期字符串,支持1-4位年份
|
||||
|
||||
Args:
|
||||
v: 日期值(可以是 None、datetime 对象或字符串)
|
||||
v: 日期值(可以是 None、datetime 对象、Neo4j DateTime 对象或字符串)
|
||||
|
||||
Returns:
|
||||
datetime 对象或 None
|
||||
"""
|
||||
if v is None or isinstance(v, datetime):
|
||||
if v is None:
|
||||
return v
|
||||
|
||||
# 处理 Neo4j DateTime 对象
|
||||
if hasattr(v, 'to_native'):
|
||||
return v.to_native()
|
||||
|
||||
# 处理 Python datetime 对象
|
||||
if isinstance(v, datetime):
|
||||
return v
|
||||
|
||||
if isinstance(v, str):
|
||||
@@ -397,6 +405,10 @@ class ExtractedEntityNode(Node):
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
entity_type: str = Field(..., description="Type of the entity")
|
||||
description: str = Field(..., description="Entity description")
|
||||
example: str = Field(
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
@@ -433,6 +445,12 @@ class ExtractedEntityNode(Node):
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
@@ -466,6 +484,8 @@ class MemorySummaryNode(Node):
|
||||
dialog_id: ID of the parent dialog
|
||||
chunk_ids: List of chunk IDs used to generate this summary
|
||||
content: Summary text content
|
||||
name: Title/name of the memory summary (generated by LLM, used as title in API)
|
||||
memory_type: Type/category of the episodic memory (e.g., Conversation, Project/Work, Learning, Decision, Important Event)
|
||||
summary_embedding: Optional embedding vector for the summary
|
||||
metadata: Additional metadata for the summary
|
||||
config_id: Configuration ID used to process this summary
|
||||
@@ -484,6 +504,7 @@ class MemorySummaryNode(Node):
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
|
||||
content: str = Field(..., description="Summary text content")
|
||||
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
@@ -38,10 +38,20 @@ class Entity(BaseModel):
|
||||
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
|
||||
type: str = Field(..., description="Type/category of the entity")
|
||||
description: str = Field(..., description="Description of the entity")
|
||||
example: str = Field(
|
||||
default="",
|
||||
description="A concise example (around 20 characters) to help understand the entity"
|
||||
)
|
||||
aliases: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Alternative names for this entity (abbreviations, full names, translations, etc.)"
|
||||
)
|
||||
|
||||
# Explicit Memory Classification
|
||||
is_explicit_memory: bool = Field(
|
||||
default=False,
|
||||
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
|
||||
)
|
||||
|
||||
|
||||
class Triplet(BaseModel):
|
||||
|
||||
@@ -42,7 +42,6 @@ from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
|
||||
embedding_generation,
|
||||
embedding_generation_all,
|
||||
generate_entity_embeddings_from_triplets,
|
||||
)
|
||||
|
||||
@@ -179,7 +178,7 @@ class ExtractionOrchestrator:
|
||||
for dialog in dialog_data_list:
|
||||
for chunk in dialog.chunks:
|
||||
all_statements_list.extend(chunk.statements)
|
||||
total_statements = len(all_statements_list)
|
||||
len(all_statements_list)
|
||||
|
||||
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
|
||||
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
|
||||
@@ -201,9 +200,9 @@ class ExtractionOrchestrator:
|
||||
all_entities_list.extend(triplet_info.entities)
|
||||
all_triplets_list.extend(triplet_info.triplets)
|
||||
|
||||
total_entities = len(all_entities_list)
|
||||
total_triplets = len(all_triplets_list)
|
||||
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
len(all_entities_list)
|
||||
len(all_triplets_list)
|
||||
sum(len(temporal_map) for temporal_map in temporal_maps)
|
||||
|
||||
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
|
||||
logger.info("步骤 3/6: 生成实体嵌入")
|
||||
@@ -385,7 +384,7 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于跟踪已完成的陈述句数量
|
||||
completed_statements = 0
|
||||
total_statements = len(all_statements)
|
||||
len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
@@ -497,7 +496,7 @@ class ExtractionOrchestrator:
|
||||
|
||||
# 用于跟踪已完成的时间提取数量
|
||||
completed_temporal = 0
|
||||
total_temporal_statements = len(all_statements)
|
||||
len(all_statements)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
async def extract_for_statement(stmt_data, stmt_index):
|
||||
@@ -1082,10 +1081,12 @@ class ExtractionOrchestrator:
|
||||
statement_id=statement.id, # 添加必需的 statement_id 字段
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
|
||||
group_id=dialog_data.group_id,
|
||||
user_id=dialog_data.user_id,
|
||||
apply_id=dialog_data.apply_id,
|
||||
|
||||
@@ -59,13 +59,28 @@ async def _process_chunk_summary(
|
||||
)
|
||||
summary_text = structured.summary.strip()
|
||||
|
||||
# Generate title and type for the summary
|
||||
title = None
|
||||
episodic_type = None
|
||||
try:
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
end_user_id=dialog.group_id
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
|
||||
# Embed the summary
|
||||
embedding = (await embedder.response([summary_text]))[0]
|
||||
|
||||
# Build node per chunk
|
||||
# Note: title is stored in the 'name' field, type is stored in 'memory_type' field
|
||||
node = MemorySummaryNode(
|
||||
id=uuid4().hex,
|
||||
name=f"MemorySummaryChunk_{chunk.id}",
|
||||
name=title if title else f"MemorySummaryChunk_{chunk.id}",
|
||||
group_id=dialog.group_id,
|
||||
user_id=dialog.user_id,
|
||||
apply_id=dialog.apply_id,
|
||||
@@ -75,6 +90,7 @@ async def _process_chunk_summary(
|
||||
dialog_id=dialog.id,
|
||||
chunk_ids=[chunk.id],
|
||||
content=summary_text,
|
||||
memory_type=episodic_type,
|
||||
summary_embedding=embedding,
|
||||
metadata={"ref_id": dialog.ref_id},
|
||||
config_id=dialog.config_id, # 添加 config_id
|
||||
|
||||
@@ -246,7 +246,7 @@ class AccessHistoryManager:
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
access_history = node_data.get('access_history', [])
|
||||
access_history = node_data.get('access_history') or []
|
||||
last_access_time = node_data.get('last_access_time')
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
@@ -409,7 +409,7 @@ class AccessHistoryManager:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
|
||||
access_history = node_data.get('access_history', [])
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
@@ -530,7 +530,7 @@ class AccessHistoryManager:
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
"""
|
||||
access_history = node_data.get('access_history', [])
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 追加新的访问时间
|
||||
|
||||
@@ -247,6 +247,9 @@ class ForgettingStrategy:
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
statement_text=statement_text,
|
||||
@@ -256,6 +259,19 @@ class ForgettingStrategy:
|
||||
db=db
|
||||
)
|
||||
|
||||
# 生成标题和类型(使用LLM)
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
try:
|
||||
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
end_user_id=group_id
|
||||
)
|
||||
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
|
||||
title = "未命名"
|
||||
episodic_type = "其他"
|
||||
|
||||
# 计算继承的激活值和重要性(取较高值)
|
||||
inherited_activation = max(statement_activation, entity_activation)
|
||||
inherited_importance = max(statement_importance, entity_importance)
|
||||
@@ -268,9 +284,6 @@ class ForgettingStrategy:
|
||||
import uuid
|
||||
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
|
||||
# 使用事务创建 MemorySummary 并删除原节点
|
||||
async def merge_transaction(tx, **params):
|
||||
"""事务函数:创建摘要节点并删除原节点"""
|
||||
@@ -287,6 +300,8 @@ class ForgettingStrategy:
|
||||
CREATE (ms:MemorySummary {
|
||||
id: $summary_id,
|
||||
summary: $summary_text,
|
||||
name: $title,
|
||||
memory_type: $episodic_type,
|
||||
original_statement_id: $statement_id,
|
||||
original_entity_id: $entity_id,
|
||||
activation_value: $inherited_activation,
|
||||
@@ -386,6 +401,8 @@ class ForgettingStrategy:
|
||||
params = {
|
||||
'summary_id': summary_id,
|
||||
'summary_text': summary_text,
|
||||
'title': title,
|
||||
'episodic_type': episodic_type,
|
||||
'statement_id': statement_id,
|
||||
'entity_id': entity_id,
|
||||
'inherited_activation': inherited_activation,
|
||||
|
||||
@@ -386,3 +386,26 @@ async def render_memory_insight_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
    """Build the prompt that asks for a title and category of an episodic summary.

    The ``episodic_type_classification.jinja2`` template is rendered with the
    given summary text. The rendered prompt and the length of the input are
    then handed to the prompt/template logging helpers.

    Args:
        content: Text of the episodic memory summary to analyse.

    Returns:
        The rendered prompt as a string.
    """
    prompt_text = prompt_env.get_template(
        "episodic_type_classification.jinja2"
    ).render(content=content)

    # Record the rendered prompt in the prompt log.
    log_prompt_rendering('episodic title and type classification', prompt_text)

    # Record template metadata; empty or None content is reported as length 0.
    if content:
        content_len = len(content)
    else:
        content_len = 0
    log_template_rendering(
        'episodic_type_classification.jinja2',
        {'content_len': content_len},
    )

    return prompt_text
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
=== Task ===
|
||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||
|
||||
=== Requirements ===
|
||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||
- Classify into exactly one category based on the primary theme
|
||||
- Be specific and avoid ambiguity
|
||||
- Output must be valid JSON conforming to the schema below
|
||||
|
||||
=== Input ===
|
||||
{{ content }}
|
||||
|
||||
=== Category Definitions ===
|
||||
|
||||
1. **conversation**: Daily communication, chat, discussion, and social interactions
|
||||
- Keywords: chat, communication, discussion, dialogue, exchange
|
||||
|
||||
2. **project_work**: Work-related tasks, projects, meetings, and collaboration
|
||||
- Keywords: project, task, work, meeting, collaboration, business, client
|
||||
|
||||
3. **learning**: Acquiring new knowledge, skill development, reading, and research
|
||||
- Keywords: learning, reading, research, knowledge, skill, course, training
|
||||
|
||||
4. **decision**: Making important decisions, choices, and planning
|
||||
- Keywords: decision, choice, planning, consideration, evaluation, weighing
|
||||
|
||||
5. **important_event**: Major events, milestones, and special experiences
|
||||
- Keywords: important, major, milestone, special, memorable, celebration
|
||||
|
||||
=== Analysis Steps ===
|
||||
1. Read the episodic memory content carefully
|
||||
2. Identify the core theme and context
|
||||
3. Extract a concise title
|
||||
4. Compare against category definitions and keywords
|
||||
5. Select the best matching category
|
||||
6. If multiple categories apply, choose the primary one
|
||||
|
||||
=== Output Schema ===
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure
|
||||
2. Escape any quotation marks within string values using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
Return only a JSON object with title and type fields:
|
||||
{
|
||||
"title": "Generated title here",
|
||||
"type": "Category type here"
|
||||
}
|
||||
|
||||
The type field must be exactly one of:
|
||||
- conversation
|
||||
- project_work
|
||||
- learning
|
||||
- decision
|
||||
- important_event
|
||||
|
||||
@@ -12,7 +12,34 @@ Extract entities and knowledge triplets from the given statement.
|
||||
===Guidelines===
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types, context-independent descriptions, and aliases
|
||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法"
|
||||
- **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库"
|
||||
* Set to `false` for:
|
||||
- **People:** "John Smith", "Dr. Wang", "张明", "王博士"
|
||||
- **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学"
|
||||
- **Locations:** "Beijing", "Central Park", "北京", "中央公园"
|
||||
- **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议"
|
||||
- **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋"
|
||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
||||
* The example should be:
|
||||
- **Specific and concrete**: Use real-world scenarios or applications
|
||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
||||
- **In the same language as the entity name**
|
||||
* Examples:
|
||||
- Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗"
|
||||
- Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed"
|
||||
- Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy"
|
||||
- Entity: "人工智能" → example: "如:智能客服、自动驾驶"
|
||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
||||
- **Aliases Extraction (Important):**
|
||||
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
|
||||
* **DO NOT translate or add aliases in different languages**
|
||||
@@ -84,21 +111,27 @@ Output:
|
||||
"name": "I",
|
||||
"type": "Person",
|
||||
"description": "The user",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Paris",
|
||||
"type": "Location",
|
||||
"description": "Capital city of France",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "Louvre",
|
||||
"type": "Location",
|
||||
"description": "World-famous museum located in Paris",
|
||||
"aliases": ["Louvre Museum"]
|
||||
"example": "",
|
||||
"aliases": ["Louvre Museum"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -130,21 +163,27 @@ Output:
|
||||
"name": "John Smith",
|
||||
"type": "Person",
|
||||
"description": "Individual person name",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Google",
|
||||
"type": "Organization",
|
||||
"description": "American technology company",
|
||||
"aliases": ["Google LLC", "Alphabet Inc."]
|
||||
"example": "",
|
||||
"aliases": ["Google LLC", "Alphabet Inc."],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI product development",
|
||||
"type": "WorkRole",
|
||||
"type": "Concept",
|
||||
"description": "Artificial intelligence product development work",
|
||||
"aliases": []
|
||||
"example": "e.g., developing chatbots, recommendation systems",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -176,21 +215,27 @@ Output:
|
||||
"name": "我",
|
||||
"type": "Person",
|
||||
"description": "用户本人",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "巴黎",
|
||||
"type": "Location",
|
||||
"description": "法国首都城市",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "卢浮宫",
|
||||
"type": "Location",
|
||||
"description": "位于巴黎的世界著名博物馆",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -222,21 +267,27 @@ Output:
|
||||
"name": "张明",
|
||||
"type": "Person",
|
||||
"description": "个人姓名",
|
||||
"aliases": []
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "腾讯",
|
||||
"type": "Organization",
|
||||
"description": "中国科技公司",
|
||||
"aliases": ["腾讯控股", "腾讯公司"]
|
||||
"example": "",
|
||||
"aliases": ["腾讯控股", "腾讯公司"],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI产品开发",
|
||||
"type": "WorkRole",
|
||||
"type": "Concept",
|
||||
"description": "人工智能产品研发工作",
|
||||
"aliases": []
|
||||
"example": "如:开发智能客服机器人、推荐系统",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -251,7 +302,9 @@ Output:
|
||||
"name": "Tripod",
|
||||
"type": "Equipment",
|
||||
"description": "Photography equipment accessory",
|
||||
"aliases": ["Camera Tripod"]
|
||||
"example": "",
|
||||
"aliases": ["Camera Tripod"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -266,7 +319,9 @@ Output:
|
||||
"name": "三脚架",
|
||||
"type": "Equipment",
|
||||
"description": "摄影器材配件",
|
||||
"aliases": ["相机三脚架"]
|
||||
"example": "",
|
||||
"aliases": ["相机三脚架"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -85,33 +85,21 @@ Example Output:
|
||||
===End of Example===
|
||||
|
||||
|
||||
===Reflection Process===
|
||||
===Internal Quality Checks (DO NOT OUTPUT)===
|
||||
|
||||
After generating the profile, perform the following self-review steps:
|
||||
Before generating your final output, internally verify:
|
||||
1. All content is grounded in provided data (no fabrication)
|
||||
2. Format follows the specified structure with correct headers
|
||||
3. Tone is objective, third-person, and neutral
|
||||
4. All four sections are complete and within character limits
|
||||
|
||||
**Step 1: Data Grounding Check**
|
||||
- Verify all statements are supported by the provided entities and statements
|
||||
- Ensure no fabricated or speculated information is included
|
||||
- Confirm all claims can be traced back to the input data
|
||||
|
||||
**Step 2: Format Compliance**
|
||||
- Verify each section follows the specified format with section headers
|
||||
- Check character count limits for each section
|
||||
- Ensure proper use of section markers (【】)
|
||||
|
||||
**Step 3: Tone and Style Review**
|
||||
- Confirm objective third-person perspective is maintained
|
||||
- Check for excessive adjectives or empty phrases
|
||||
- Verify neutral and restrained tone throughout
|
||||
|
||||
**Step 4: Completeness Check**
|
||||
- Ensure all four sections are present and complete
|
||||
- Verify each section addresses its specific focus area
|
||||
- Confirm the one-sentence summary effectively captures the user's essence
|
||||
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
|
||||
|
||||
|
||||
===Output Requirements===
|
||||
|
||||
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
- The output language should ALWAYS be Chinese (Simplified)
|
||||
- All section content must be in Chinese
|
||||
@@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps:
|
||||
- Content follows immediately after the header
|
||||
- Sections are separated by blank lines
|
||||
- Strictly adhere to character limits for each section
|
||||
- **DO NOT include any text after the 【一句话总结】 section**
|
||||
- **DO NOT output reflection steps, self-review, or verification notes**
|
||||
|
||||
@@ -64,8 +64,8 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
|
||||
|
||||
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None, **kwargs):
|
||||
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
|
||||
app_id = os.environ.get("TEXTLN_APP_ID", "fa3f24380683ad53e6c620c0f0878a09")
|
||||
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "6130caac9aabc6eb26433758d7898f4a")
|
||||
app_id = os.environ.get("TEXTLN_APP_ID", "")
|
||||
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
|
||||
pdf_parser = TextLnParser(textln_api=textln_api, app_id=app_id, secret_code=secret_code)
|
||||
|
||||
sections, tables = pdf_parser.parse_pdf(
|
||||
|
||||
@@ -448,7 +448,7 @@ if __name__ == "__main__":
|
||||
# 准备配置vision_model信息
|
||||
# 初始化 QWenCV
|
||||
vision_model = QWenCV(
|
||||
key="sk-8e9e40cd171749858ce2d3722ea75669",
|
||||
key="",
|
||||
model_name="qwen-vl-max",
|
||||
lang="Chinese", # 默认使用中文
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
|
||||
@@ -191,10 +191,14 @@ class BaseTool(ABC):
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def to_langchain_tool(self):
|
||||
"""转换为Langchain工具格式"""
|
||||
def to_langchain_tool(self, operation: Optional[str] = None):
|
||||
"""转换为Langchain工具格式
|
||||
|
||||
Args:
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
"""
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self)
|
||||
return LangchainAdapter.convert_tool(self, operation)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
216
api/app/core/tools/builtin/operation_tool.py
Normal file
216
api/app/core/tools/builtin/operation_tool.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""操作工具 - 为特定操作创建的工具包装器"""
|
||||
from typing import List
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.models import ToolType
|
||||
|
||||
|
||||
class OperationTool(BaseTool):
    """Expose one operation of a multi-operation tool as a standalone tool.

    The wrapper reuses the wrapped tool's id and config, derives its name and
    description from the wrapped tool plus the operation name, advertises only
    the parameters that belong to that operation, and injects ``operation``
    into the keyword arguments when executing.
    """

    def __init__(self, base_tool: BaseTool, operation: str):
        """Store the wrapped tool and operation, then initialise the base class.

        Args:
            base_tool: Tool whose operation is being wrapped.
            operation: Name of the operation this wrapper represents.
        """
        # Assigned before the base initialiser runs; presumably so that the
        # properties below are usable during base initialisation - TODO confirm.
        self.base_tool = base_tool
        self.operation = operation
        super().__init__(base_tool.tool_id, base_tool.config)

    @property
    def name(self) -> str:
        """Wrapped tool name and operation joined by an underscore."""
        return f"{self.base_tool.name}_{self.operation}"

    @property
    def tool_type(self) -> ToolType:
        """Tool type; always ``ToolType.BUILTIN``."""
        return ToolType.BUILTIN

    @property
    def description(self) -> str:
        """Wrapped tool description followed by the operation name."""
        return f"{self.base_tool.description} - {self.operation}"

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameters applicable to this wrapper's operation.

        ``datetime_tool`` and ``json_tool`` get hand-written per-operation
        parameter lists; any other tool exposes all of its own parameters
        except the one named ``operation``.
        """
        wrapped_name = self.base_tool.name
        if wrapped_name == 'datetime_tool':
            return self._get_datetime_params()
        if wrapped_name == 'json_tool':
            return self._get_json_params()
        return [param for param in self.base_tool.parameters if param.name != "operation"]

    @staticmethod
    def _string_param(name: str, description: str, **options) -> ToolParameter:
        """Create a STRING parameter; ``options`` carries ``required``/``default``."""
        return ToolParameter(
            name=name,
            type=ParameterType.STRING,
            description=description,
            **options,
        )

    @classmethod
    def _time_format_param(cls, name: str, description: str) -> ToolParameter:
        """Create an optional time-format parameter with the standard default."""
        return cls._string_param(
            name, description, required=False, default="%Y-%m-%d %H:%M:%S"
        )

    @classmethod
    def _timezone_param(cls, name: str, description: str) -> ToolParameter:
        """Create an optional timezone parameter defaulting to Asia/Shanghai."""
        return cls._string_param(
            name, description, required=False, default="Asia/Shanghai"
        )

    @classmethod
    def _json_path_param(cls) -> ToolParameter:
        """Create the required ``json_path`` parameter."""
        return cls._string_param(
            "json_path",
            "JSON路径表达式(如:$.user.name或users[0].name)",
            required=True,
        )

    def _get_datetime_params(self) -> List[ToolParameter]:
        """Return the parameter list of the current ``datetime_tool`` operation.

        Unknown operations yield an empty list.
        """
        operation = self.operation

        # Factories rather than shared instances, so every call hands out
        # freshly built parameter objects.
        def input_value() -> ToolParameter:
            return self._string_param(
                "input_value", "输入值(时间字符串或时间戳)", required=True
            )

        def input_format() -> ToolParameter:
            return self._time_format_param(
                "input_format", "输入时间格式(如:%Y-%m-%d %H:%M:%S)"
            )

        def output_format() -> ToolParameter:
            return self._time_format_param(
                "output_format", "输出时间格式(如:%Y-%m-%d %H:%M:%S)"
            )

        def from_timezone() -> ToolParameter:
            return self._timezone_param(
                "from_timezone", "源时区(如:UTC, Asia/Shanghai)"
            )

        def to_timezone() -> ToolParameter:
            return self._timezone_param(
                "to_timezone", "目标时区(如:UTC, Asia/Shanghai)"
            )

        if operation == "now":
            return [to_timezone(), output_format()]
        if operation == "format":
            return [input_value(), input_format(), output_format()]
        if operation == "convert_timezone":
            return [
                input_value(),
                input_format(),
                output_format(),
                from_timezone(),
                to_timezone(),
            ]
        if operation == "timestamp_to_datetime":
            return [input_value(), output_format(), to_timezone()]
        return []

    def _get_json_params(self) -> List[ToolParameter]:
        """Return the parameter list of the current ``json_tool`` operation.

        Every operation takes ``input_data``; unknown operations take only that.
        """
        params = [
            self._string_param(
                "input_data",
                "输入数据(JSON字符串、YAML字符串或XML字符串)",
                required=True,
            )
        ]
        operation = self.operation

        if operation == "insert":
            params += [
                self._json_path_param(),
                self._string_param(
                    "new_value", "新值(用于insert操作)", required=True
                ),
            ]
        elif operation == "replace":
            params += [
                self._json_path_param(),
                self._string_param(
                    "old_text", "要替换的原文本(用于replace操作)", required=True
                ),
                self._string_param(
                    "new_text", "替换后的新文本(用于replace操作)", required=True
                ),
            ]
        elif operation in ("delete", "parse"):
            # Both operations need only the path in addition to the input data.
            params.append(self._json_path_param())

        return params

    async def execute(self, **kwargs) -> ToolResult:
        """Run the wrapped tool with ``operation`` forced to this wrapper's operation."""
        call_args = {**kwargs, "operation": self.operation}
        return await self.base_tool.execute(**call_args)
|
||||
@@ -1,4 +1,5 @@
|
||||
"""自定义工具基类"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
if isinstance(self.schema_content, str):
|
||||
try:
|
||||
self.schema_content = json.loads(self.schema_content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
|
||||
@@ -21,24 +21,35 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
|
||||
# 内部工具实例
|
||||
tool_instance: BaseTool = Field(..., description="内部工具实例")
|
||||
# 特定操作(用于自定义工具)
|
||||
operation: Optional[str] = Field(None, description="特定操作")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, tool_instance: BaseTool, **kwargs):
|
||||
def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs):
|
||||
"""初始化Langchain工具包装器
|
||||
|
||||
Args:
|
||||
tool_instance: 内部工具实例
|
||||
operation: 特定操作(用于自定义工具)
|
||||
"""
|
||||
# 动态创建参数schema
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(
|
||||
tool_instance.parameters, operation
|
||||
)
|
||||
|
||||
# 构建工具名称
|
||||
tool_name = tool_instance.name
|
||||
if operation:
|
||||
tool_name = f"{tool_instance.name}_{operation}"
|
||||
|
||||
super().__init__(
|
||||
name=tool_instance.name,
|
||||
name=tool_name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
_tool_instance=tool_instance,
|
||||
tool_instance=tool_instance,
|
||||
operation=operation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -58,8 +69,12 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
) -> str:
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 如果有特定操作,添加到参数中
|
||||
if self.operation:
|
||||
kwargs["operation"] = self.operation
|
||||
|
||||
# 执行内部工具
|
||||
result = await self._tool_instance.safe_execute(**kwargs)
|
||||
result = await self.tool_instance.safe_execute(**kwargs)
|
||||
|
||||
# 转换结果为Langchain格式
|
||||
return LangchainAdapter._format_result_for_langchain(result)
|
||||
@@ -73,24 +88,82 @@ class LangchainAdapter:
|
||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||
|
||||
@staticmethod
|
||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
||||
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
|
||||
"""将内部工具转换为Langchain工具
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
operation: 特定操作(适用于有操作的工具)或MCP工具名称
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
# 处理MCP工具的特定工具名称
|
||||
if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation:
|
||||
# 为MCP工具创建特定工具名称的实例
|
||||
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
|
||||
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
elif operation and LangchainAdapter._tool_supports_operations(tool):
|
||||
# 为支持多操作的工具创建特定操作实例
|
||||
if tool.tool_type.value == "custom":
|
||||
# 自定义工具直接传递operation参数
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation)
|
||||
else:
|
||||
# 内置工具使用OperationTool包装
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
else:
|
||||
# 单个工具
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
def _tool_supports_operations(tool: BaseTool) -> bool:
    """Report whether a tool exposes several operations.

    A builtin tool qualifies when its name is ``datetime_tool`` or
    ``json_tool``. A custom tool qualifies when it has more than one entry in
    ``_parsed_operations``, or when it declares an ``operation`` parameter
    with a non-empty ``enum``. All other tools do not qualify.
    """
    kind = tool.tool_type.value

    if kind == "builtin":
        # Only these builtin tools are split per operation.
        return tool.name in ('datetime_tool', 'json_tool')

    if kind == "custom":
        # Custom tools get their operations from the parsed OpenAPI schema.
        if len(getattr(tool, '_parsed_operations', ())) > 1:
            return True
        # Otherwise look for an enumerated "operation" parameter.
        return any(
            candidate.name == "operation" and candidate.enum
            for candidate in tool.parameters
        )

    return False
|
||||
|
||||
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
    """Wrap a builtin tool so that it represents a single operation.

    Args:
        base_tool: Tool to wrap; its ``tool_type.value`` must be ``"builtin"``.
        operation: Name of the operation the wrapper should expose.

    Returns:
        An ``OperationTool`` bound to ``base_tool`` and ``operation``.

    Raises:
        ValueError: If ``base_tool`` is not a builtin tool.
    """
    if base_tool.tool_type.value != "builtin":
        raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}")

    # Imported lazily, as in the original implementation.
    from app.core.tools.builtin.operation_tool import OperationTool

    return OperationTool(base_tool, operation)
|
||||
|
||||
@staticmethod
def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool:
    """Select ``tool_name`` on the given MCP tool and return that same tool.

    Args:
        mcp_tool: MCP tool instance; it must provide ``set_current_tool``.
        tool_name: Name passed to ``set_current_tool``.

    Returns:
        The ``mcp_tool`` object that was passed in (not a copy).
    """
    # NOTE(review): the passed instance is mutated and returned, so repeated
    # calls with different names on one instance all yield the same object -
    # confirm that callers do not rely on independent instances.
    mcp_tool.set_current_tool(tool_name)
    return mcp_tool
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
@@ -110,15 +183,19 @@ class LangchainAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
||||
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
|
||||
def _create_pydantic_schema(
|
||||
parameters: List[ToolParameter],
|
||||
operation: Optional[str] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""根据工具参数创建Pydantic schema
|
||||
|
||||
Args:
|
||||
parameters: 工具参数列表
|
||||
operation: 特定操作(用于过滤参数)
|
||||
|
||||
Returns:
|
||||
Pydantic模型类
|
||||
@@ -127,7 +204,12 @@ class LangchainAdapter:
|
||||
fields = {}
|
||||
annotations = {}
|
||||
|
||||
for param in parameters:
|
||||
# 如果指定了operation,过滤掉operation参数
|
||||
filtered_params = parameters
|
||||
if operation:
|
||||
filtered_params = [p for p in parameters if p.name != "operation"]
|
||||
|
||||
for param in filtered_params:
|
||||
# 确定Python类型
|
||||
python_type = LangchainAdapter._get_python_type(param.type)
|
||||
|
||||
@@ -169,9 +251,10 @@ class LangchainAdapter:
|
||||
"ToolArgsSchema",
|
||||
(BaseModel,),
|
||||
{
|
||||
"__module__": __name__,
|
||||
"__annotations__": annotations,
|
||||
**fields,
|
||||
"Config": type("Config", (), {"extra": "forbid"})
|
||||
"model_config": {"extra": "forbid"},
|
||||
**fields
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
"""MCP工具模块"""
|
||||
"""MCP 工具模块 - Model Context Protocol 支持"""
|
||||
|
||||
from app.core.tools.mcp.base import MCPTool
|
||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
||||
from app.core.tools.mcp.service_manager import MCPServiceManager
|
||||
# 主要类导出
|
||||
from .base import MCPTool, MCPToolManager, MCPError
|
||||
from .client import SimpleMCPClient, MCPConnectionError
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
# 核心类
|
||||
"MCPTool",
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPToolManager",
|
||||
"MCPError",
|
||||
|
||||
# 客户端类
|
||||
"SimpleMCPClient",
|
||||
"MCPConnectionError",
|
||||
|
||||
# 服务管理(简化版)
|
||||
"MCPServiceManager"
|
||||
]
|
||||
@@ -1,10 +1,9 @@
|
||||
"""MCP工具基类"""
|
||||
"""MCP工具基类 - 整合版本"""
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -14,215 +13,188 @@ class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
    """Initialise the MCP tool from its configuration.

    Args:
        tool_id: Tool identifier.
        config: Tool configuration. ``server_url``, ``connection_config`` and
            ``available_tools`` are read from it, each falling back to an
            empty value when the key is missing.
    """
    super().__init__(tool_id, config)
    self.server_url = config.get("server_url", "")
    self.connection_config = config.get("connection_config", {})
    self.available_tools = config.get("available_tools", [])
    # No client exists and no connection is open until one is established.
    self._client = None
    self._connected = False
|
||||
|
||||
@property
def name(self) -> str:
    """Tool name: ``mcp_tool_`` followed by the first eight characters of the tool id."""
    return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
def description(self) -> str:
    """Tool description, which includes the configured server URL."""
    return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
def tool_type(self) -> ToolType:
    """Tool category; always ``ToolType.MCP``."""
    return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
"""根据工具名称返回对应参数"""
|
||||
# 如果有指定的工具名称,从 available_tools 中获取参数
|
||||
tool_name = getattr(self, '_current_tool_name', None)
|
||||
if tool_name and self.available_tools:
|
||||
for tool_info in self.available_tools:
|
||||
if tool_info.get("tool_name") == tool_name:
|
||||
arguments = tool_info.get("arguments", {})
|
||||
return self._generate_parameters_from_schema(arguments)
|
||||
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
# 默认返回通用参数
|
||||
return [
|
||||
ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
description="要执行的工具名称",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数(JSON对象)",
|
||||
description="工具参数",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
])
|
||||
]
|
||||
|
||||
def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]:
    """Build ``ToolParameter`` definitions from a JSON-Schema-style object.

    Args:
        arguments: Schema dict whose ``properties`` maps parameter names to
            their definitions and whose ``required`` lists mandatory names.
            ``None``, a missing key, or a ``null`` value for either key is
            treated as empty instead of raising.

    Returns:
        One ``ToolParameter`` per entry in ``properties``, in schema order.
        An empty list when the schema declares no properties.
    """
    schema = arguments or {}
    properties = schema.get("properties") or {}
    required_fields = schema.get("required") or []

    params = []
    for param_name, param_def in properties.items():
        # JSON Schema also allows boolean property schemas (``true``/``false``);
        # treat anything that is not a dict as a definition with no details.
        if not isinstance(param_def, dict):
            param_def = {}

        param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))

        params.append(ToolParameter(
            name=param_name,
            type=param_type,
            description=param_def.get("description", f"参数: {param_name}"),
            required=param_name in required_fields,
            default=param_def.get("default"),
            enum=param_def.get("enum"),
            minimum=param_def.get("minimum"),
            maximum=param_def.get("maximum")
        ))

    return params
|
||||
|
||||
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
    """Map a JSON Schema ``type`` value to a ``ParameterType``.

    Args:
        json_type: JSON Schema type name. A list of type names (for example
            ``["string", "null"]``) is also accepted; the first entry other
            than ``"null"`` is used.

    Returns:
        The matching ``ParameterType``, or ``ParameterType.STRING`` when the
        type is unknown, missing, or not a string.
    """
    type_mapping = {
        "string": ParameterType.STRING,
        "integer": ParameterType.INTEGER,
        "number": ParameterType.NUMBER,
        "boolean": ParameterType.BOOLEAN,
        "array": ParameterType.ARRAY,
        "object": ParameterType.OBJECT
    }

    # JSON Schema lets "type" be an array of names; a list is unhashable and
    # would make the dict lookup below raise TypeError, so reduce it first.
    if isinstance(json_type, (list, tuple)):
        json_type = next((name for name in json_type if name != "null"), "string")

    # Any other non-string value cannot be a key of the mapping.
    if not isinstance(json_type, str):
        return ParameterType.STRING

    return type_mapping.get(json_type, ParameterType.STRING)
|
||||
|
||||
def set_current_tool(self, tool_name: str) -> None:
    """Record the tool name that tool-specific parameter lookup should use.

    Args:
        tool_name: Name stored on the instance as ``_current_tool_name``.
    """
    self._current_tool_name = tool_name
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
raise Exception("未指定工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
client = SimpleMCPClient(self.server_url, self.connection_config)
|
||||
|
||||
async with client:
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"MCP工具执行失败: {kwargs.get('tool_name', 'unknown')}, 错误: {e}")
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
error_code="MCP_EXECUTION_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
|
||||
class MCPError(Exception):
    """Base class for MCP errors."""
|
||||
|
||||
|
||||
class MCPToolManager:
|
||||
"""MCP 工具管理器 - 简化版本"""
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
def __init__(self, db=None):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
|
||||
|
||||
async def discover_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> tuple[bool, List[Dict[str, Any]], str | None]:
|
||||
"""发现 MCP 服务器上的工具"""
|
||||
try:
|
||||
from .client import MCPClient
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 缓存工具信息
|
||||
self._tool_cache[server_url] = {
|
||||
"tools": tools,
|
||||
"connection_config": connection_config,
|
||||
"last_updated": time.time()
|
||||
}
|
||||
|
||||
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
|
||||
return True, tools, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
error_msg = f"发现工具失败: {e}"
|
||||
logger.error(error_msg)
|
||||
return False, [], error_msg
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
    """Return a snapshot of this tool's connection state.

    No request is made; the values are read from instance attributes.

    Returns:
        A dict with the ``connected`` flag, the ``server_url``, the
        ``available_tools`` list and ``last_check`` set to the current
        ``time.time()`` value.
    """
    return {
        "connected": self._connected,
        "server_url": self.server_url,
        "available_tools": self.available_tools,
        "last_check": time.time()
    }
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
    """Invoke one MCP tool through the connected client.

    Args:
        tool_name: Name of the tool to call.
        arguments: Arguments forwarded to the tool.
        timeout: Timeout value forwarded to the client.

    Returns:
        Whatever the client's ``call_tool`` returns.

    Raises:
        Exception: If there is no client or it is not connected. Errors from
            the client call are logged and re-raised unchanged.
    """
    if not (self._client and self._connected):
        raise Exception("MCP客户端未连接")

    try:
        return await self._client.call_tool(tool_name, arguments, timeout)
    except Exception as exc:
        logger.error(f"MCP工具调用失败: {tool_name}, 错误: {exc}")
        raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
    """List the tools offered by the MCP server.

    Connects first when no connection is open, then refreshes
    ``available_tools`` with the names of the returned tools.

    Returns:
        The tool descriptions returned by the client, or an empty list when
        there is no client or any step fails (the failure is logged).
    """
    try:
        if not self._connected:
            await self.connect()

        client = self._client
        if not client:
            return []

        listed = await client.list_tools()

        # Keep only entries that carry a non-empty name.
        names = []
        for entry in listed:
            label = entry.get("name")
            if label:
                names.append(label)
        self.available_tools = names

        return listed

    except Exception as exc:
        logger.error(f"获取MCP工具列表失败: {exc}")
        return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tools_count": len(tools),
|
||||
"tools": [tool.get("name") for tool in tools],
|
||||
"message": "连接成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
"message": "连接失败"
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
"""MCP客户端 - 简化版本"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
@@ -18,139 +17,156 @@ class MCPConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPProtocolError(Exception):
    """Raised for MCP protocol-level errors."""
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
class SimpleMCPClient:
|
||||
"""简化的 MCP 客户端"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
self.timeout = self.connection_config.get("timeout", 30)
|
||||
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
self._pending_requests = {}
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
async def __aenter__(self):
    """Enter the async context: connect, then return this client."""
    await self.connect()
    return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Exit the async context by disconnecting.

    The exception details are ignored, and nothing is returned, so an
    exception raised inside the ``async with`` body still propagates.
    """
    await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
"""建立连接"""
|
||||
try:
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
if self.is_websocket:
|
||||
await self._connect_websocket()
|
||||
else:
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
await self._connect_http()
|
||||
except Exception as e:
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||
raise MCPConnectionError(f"连接失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
elif self._session:
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
async def _connect_websocket(self):
    """Open the WebSocket connection and perform the initialize handshake.

    The handshake runs before the background reader is started. The
    handshake itself waits on ``recv()``; starting the reader first would
    put two coroutines in ``recv()`` on the same socket, so the reader could
    consume the initialize response or the second ``recv()`` could fail.

    Raises:
        Exception: Any error from opening the socket or from the handshake.
            If the handshake fails, the socket is closed before re-raising.
    """
    headers = self._build_headers()

    self._websocket = await websockets.connect(
        self.server_url,
        extra_headers=headers,
        timeout=self.timeout
    )

    try:
        # Send the initialize message and wait for its response.
        await self._send_initialize()
    except Exception:
        # The caller never reaches disconnect() when connecting fails, so
        # release the socket here.
        await self._websocket.close()
        self._websocket = None
        raise

    # Start the reader only once the handshake is done. Keep a reference so
    # the task is not garbage-collected while it is still running.
    self._reader_task = asyncio.create_task(self._handle_websocket_messages())
|
||||
|
||||
async def _connect_http(self):
    """Create the HTTP session used for MCP requests.

    For ModelScope-hosted servers (URL contains ``modelscope.net``) an
    initialize request is sent right after the session is created.

    Raises:
        Exception: Any error from ModelScope initialisation. The session is
            closed before re-raising so it is not leaked.
    """
    headers = self._build_headers()
    timeout = aiohttp.ClientTimeout(total=self.timeout)

    self._session = aiohttp.ClientSession(
        headers=headers,
        timeout=timeout
    )

    # ModelScope MCP services need an initialize request first.
    if "modelscope.net" in self.server_url:
        try:
            await self._initialize_modelscope_session()
        except Exception:
            # A failed connect means the context manager's exit never runs,
            # so close the session here instead.
            await self._session.close()
            self._session = None
            raise
|
||||
|
||||
async def _initialize_modelscope_session(self):
    """Run the MCP initialize exchange over HTTP for a ModelScope server.

    Posts an ``initialize`` request, stores the returned session id (if any)
    in the session's default headers, then posts the
    ``notifications/initialized`` notification.

    Raises:
        MCPConnectionError: If the initialize request does not return HTTP
            200, if its JSON body contains an ``error`` entry, or if aiohttp
            raises a client error during either request.
    """
    init_request = {
        "jsonrpc": "2.0",
        "id": self._get_request_id(),
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {"tools": {}},
            "clientInfo": {
                "name": "MemoryBear",
                "version": "1.0.0"
            }
        }
    }

    try:
        async with self._session.post(
            self.server_url,
            json=init_request
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")

            # NOTE(review): the body is decoded as JSON; a server that answers
            # with an event stream would presumably fail here. Confirm.
            init_response = await response.json()
            if "error" in init_response:
                raise MCPConnectionError(f"初始化失败: {init_response['error']}")

            # Read the session id from the response headers.
            session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
            if session_id:
                # Add it to the session's default headers for later requests.
                self._session.headers.update({"Mcp-Session-Id": session_id})

            # Send the "initialized" notification (it has no "id" field).
            initialized_notification = {
                "jsonrpc": "2.0",
                "method": "notifications/initialized"
            }

            # The notification's response is not inspected.
            async with self._session.post(
                self.server_url,
                json=initialized_notification
            ) as notif_response:
                pass

    except aiohttp.ClientError as e:
        raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
if auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
elif auth_type == "api_key":
|
||||
key = auth_config.get("api_key")
|
||||
header_name = auth_config.get("key_name", "X-API-Key")
|
||||
if key:
|
||||
headers[header_name] = key
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
@@ -161,160 +177,63 @@ class MCPClient:
|
||||
|
||||
return headers
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息"""
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if init_response.get("error", None) is not None:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
async def _connect_http(self) -> bool:
    """Create the HTTP session and probe the server for reachability.

    The ``/health`` path is tried first; if it does not return HTTP 200 the
    server URL itself is tried and any status below 400 counts as reachable.

    Returns:
        True if the server answered one of the probes acceptably. False
        otherwise, in which case any session created here is closed and
        cleared.
    """
    try:
        timeout = aiohttp.ClientTimeout(total=self.connection_timeout)

        # Copy the configured headers so that adding auth headers does not
        # write credentials back into the shared connection_config dict.
        headers = dict(self.connection_config.get("headers") or {})
        headers.update(self._build_auth_headers())

        self._session = aiohttp.ClientSession(
            timeout=timeout,
            headers=headers
        )

        # Probe the health endpoint first.
        test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"

        async with self._session.get(test_url) as response:
            reachable = response.status == 200

        if not reachable:
            # Fall back to the root path.
            async with self._session.get(self.server_url) as root_response:
                reachable = root_response.status < 400

    except Exception as e:
        logger.error(f"HTTP连接失败: {e}")
        reachable = False

    # Release the session on every failure path, including a rejected root
    # probe, which previously left the session open.
    if not reachable and self._session:
        await self._session.close()
        self._session = None

    return reachable
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
data = json.loads(message)
|
||||
|
||||
# 处理响应
|
||||
if "id" in data:
|
||||
request_id = str(data["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.error(f"WebSocket消息处理异常: {e}")
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
    """Dispatch one decoded incoming message.

    A message with an ``id`` resolves the pending request future registered
    under that id (as a string), if there is one. A message without ``id``
    but with ``method`` is passed to the notification handler. Any error is
    logged and not propagated.

    Args:
        message: Decoded message dict.
    """
    try:
        if "id" not in message:
            # Not a response; it may be a notification.
            if "method" in message:
                await self._handle_notification(message)
            return

        key = str(message["id"])
        if key in self._pending_requests:
            waiter = self._pending_requests.pop(key)
            if not waiter.done():
                waiter.set_result(message)

    except Exception as exc:
        logger.error(f"处理消息失败: {exc}")
|
||||
|
||||
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
    """Log an incoming notification message.

    Args:
        message: Decoded message; ``method`` and ``params`` are read from it.
    """
    method = message.get("method")
    params = message.get("params", {})

    logger.debug(f"收到MCP通知: {method}, 参数: {params}")

    # Specific notifications can be handled here as needed,
    # for example tool-list updates or server status changes.
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
@@ -322,343 +241,69 @@ class MCPClient:
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
if self.connection_type == "websocket":
|
||||
request_id = str(request_data["id"])
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
# 创建Future等待响应
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
request_id = str(request_data["id"])
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
self.server_url,
|
||||
json=request_data
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as root_response:
|
||||
if root_response.status != 200:
|
||||
error_text = await root_response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
def _get_request_id(self) -> str:
|
||||
"""获取请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
|
||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
||||
|
||||
def __init__(self, max_connections: int = 10):
|
||||
"""初始化连接池
|
||||
|
||||
Args:
|
||||
max_connections: 最大连接数
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self._clients: Dict[str, MCPClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
||||
"""获取或创建MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
|
||||
Returns:
|
||||
MCP客户端实例
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_url in self._clients:
|
||||
client = self._clients[server_url]
|
||||
if client.is_connected:
|
||||
return client
|
||||
else:
|
||||
# 尝试重连
|
||||
if await client.connect():
|
||||
return client
|
||||
else:
|
||||
# 移除失效的客户端
|
||||
del self._clients[server_url]
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self._clients) >= self.max_connections:
|
||||
# 移除最旧的连接
|
||||
oldest_url = next(iter(self._clients))
|
||||
await self._clients[oldest_url].disconnect()
|
||||
del self._clients[oldest_url]
|
||||
|
||||
# 创建新客户端
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if await client.connect():
|
||||
self._clients[server_url] = client
|
||||
return client
|
||||
else:
|
||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""断开所有连接"""
|
||||
async with self._lock:
|
||||
for client in self._clients.values():
|
||||
await client.disconnect()
|
||||
self._clients.clear()
|
||||
|
||||
def get_pool_status(self) -> Dict[str, Any]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
"total_connections": len(self._clients),
|
||||
"max_connections": self.max_connections,
|
||||
"connections": {
|
||||
url: client.get_connection_info()
|
||||
for url, client in self._clients.items()
|
||||
}
|
||||
}
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
@@ -1,6 +1,4 @@
|
||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
||||
import asyncio
|
||||
import time
|
||||
"""MCP服务管理器 - 简化版本"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
@@ -8,133 +6,53 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
||||
from app.core.tools.mcp.base import MCPToolManager
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
"""MCP服务管理器 - 简化版本,主要用于工具创建"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
def __init__(self, db: Session = None):
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
||||
|
||||
# 配置
|
||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
||||
self.max_retry_attempts = 3 # 最大重试次数
|
||||
self.retry_delay = 5 # 重试延迟(秒)
|
||||
|
||||
# 状态
|
||||
self._running = False
|
||||
self._manager_task = None
|
||||
self.tool_manager = MCPToolManager(db) if db else None
|
||||
|
||||
async def start(self):
|
||||
"""启动服务管理器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("MCP服务管理器启动")
|
||||
|
||||
# 加载现有服务
|
||||
await self._load_existing_services()
|
||||
|
||||
# 启动管理任务
|
||||
self._manager_task = asyncio.create_task(self._management_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("MCP服务管理器停止")
|
||||
|
||||
# 停止管理任务
|
||||
if self._manager_task:
|
||||
self._manager_task.cancel()
|
||||
try:
|
||||
await self._manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有监控任务
|
||||
for task in self._monitoring_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self._monitoring_tasks:
|
||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
||||
|
||||
self._monitoring_tasks.clear()
|
||||
|
||||
# 断开所有连接
|
||||
await self.connection_pool.disconnect_all()
|
||||
|
||||
async def register_service(
|
||||
async def create_mcp_tool(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
tool_name: str,
|
||||
service_name: str = None
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""注册MCP服务
|
||||
"""创建单个MCP工具
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
tenant_id: 租户ID
|
||||
service_name: 服务名称(可选)
|
||||
tool_name: 具体工具名称
|
||||
service_name: 服务名称
|
||||
|
||||
Returns:
|
||||
(是否成功, 服务ID或错误信息, 错误详情)
|
||||
(是否成功, 工具ID或错误信息, 错误详情)
|
||||
"""
|
||||
try:
|
||||
# 检查服务是否已存在
|
||||
existing_service = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.server_url == server_url
|
||||
).first()
|
||||
|
||||
if existing_service:
|
||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
||||
|
||||
# 测试连接
|
||||
try:
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if not await client.connect():
|
||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
||||
|
||||
# 获取可用工具
|
||||
available_tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
return False, "连接测试失败", str(e)
|
||||
if not service_name:
|
||||
service_name = f"mcp_{tool_name}"
|
||||
|
||||
# 创建工具配置
|
||||
if not service_name:
|
||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
||||
|
||||
tool_config = ToolConfig(
|
||||
name=service_name,
|
||||
description=f"MCP服务 - {server_url}",
|
||||
description=f"MCP工具: {tool_name}",
|
||||
tool_type=ToolType.MCP.value,
|
||||
tenant_id=tenant_id,
|
||||
version="1.0.0",
|
||||
status=ToolStatus.AVAILABLE.value,
|
||||
config_data={
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config
|
||||
"connection_config": connection_config,
|
||||
"tool_name": tool_name
|
||||
}
|
||||
)
|
||||
|
||||
@@ -146,460 +64,22 @@ class MCPServiceManager:
|
||||
id=tool_config.id,
|
||||
server_url=server_url,
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
available_tools=[tool_name],
|
||||
health_status="unknown",
|
||||
last_health_check=datetime.now()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
self.db.commit()
|
||||
|
||||
service_id = str(tool_config.id)
|
||||
|
||||
# 添加到内存管理
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config,
|
||||
"tenant_id": tenant_id,
|
||||
"available_tools": tool_names,
|
||||
"status": "healthy",
|
||||
"last_health_check": time.time(),
|
||||
"retry_count": 0,
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
||||
return True, service_id, None
|
||||
logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
|
||||
return True, str(tool_config.id), None
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
||||
return False, "注册失败", str(e)
|
||||
logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
|
||||
return False, "创建失败", str(e)
|
||||
|
||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
||||
"""注销MCP服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 从数据库删除
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if not tool_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 停止监控
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 从内存移除
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
|
||||
logger.info(f"MCP服务注销成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def update_service(
|
||||
self,
|
||||
service_id: str,
|
||||
connection_config: Dict[str, Any] = None,
|
||||
enabled: bool = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""更新MCP服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
connection_config: 新的连接配置
|
||||
enabled: 是否启用
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
tool_config = mcp_config.base_config
|
||||
|
||||
if connection_config is not None:
|
||||
mcp_config.connection_config = connection_config
|
||||
tool_config.config_data["connection_config"] = connection_config
|
||||
|
||||
if enabled is not None:
|
||||
tool_config.is_enabled = enabled
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新内存状态
|
||||
if service_id in self._services:
|
||||
if connection_config is not None:
|
||||
self._services[service_id]["connection_config"] = connection_config
|
||||
|
||||
# 如果配置有变化,重启监控
|
||||
if connection_config is not None:
|
||||
await self._restart_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务更新成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取服务状态
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务状态信息
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return None
|
||||
|
||||
service_info = self._services[service_id].copy()
|
||||
|
||||
# 添加实时健康检查
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
service_info["real_time_health"] = health_status
|
||||
|
||||
except Exception as e:
|
||||
service_info["real_time_health"] = {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
return service_info
|
||||
|
||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
|
||||
"""列出所有服务
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
|
||||
Returns:
|
||||
服务列表
|
||||
"""
|
||||
services = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
if tenant_id and service_info["tenant_id"] != tenant_id:
|
||||
continue
|
||||
|
||||
services.append(service_info.copy())
|
||||
|
||||
return services
|
||||
|
||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取服务的可用工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return []
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 更新缓存的工具列表
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.available_tools = tool_names
|
||||
self.db.commit()
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def call_service_tool(
|
||||
self,
|
||||
service_id: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""调用服务工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
raise ValueError(f"服务不存在: {service_id}")
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
result = await client.call_tool(tool_name, arguments, timeout)
|
||||
|
||||
# 更新服务状态为健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["last_health_check"] = time.time()
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 更新服务状态为错误
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def _load_existing_services(self):
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
||||
ToolConfig.tool_type == ToolType.MCP.value
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
tool_config = mcp_config.base_config
|
||||
service_id = str(mcp_config.id)
|
||||
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"tenant_id": tool_config.tenant_id,
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"status": mcp_config.health_status or "unknown",
|
||||
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
|
||||
"retry_count": 0,
|
||||
"created_at": tool_config.created_at.timestamp()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载现有服务失败: {e}")
|
||||
|
||||
async def _start_service_monitoring(self, service_id: str):
|
||||
"""启动服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._monitor_service(service_id))
|
||||
self._monitoring_tasks[service_id] = task
|
||||
|
||||
async def _stop_service_monitoring(self, service_id: str):
|
||||
"""停止服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
task = self._monitoring_tasks.pop(service_id)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _restart_service_monitoring(self, service_id: str):
|
||||
"""重启服务监控"""
|
||||
await self._stop_service_monitoring(service_id)
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
async def _monitor_service(self, service_id: str):
|
||||
"""监控单个服务"""
|
||||
try:
|
||||
while self._running and service_id in self._services:
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
# 执行健康检查
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
|
||||
if health_status["healthy"]:
|
||||
# 服务健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
# 更新工具列表
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
except Exception as e:
|
||||
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
|
||||
|
||||
else:
|
||||
# 服务不健康
|
||||
service_info["status"] = "unhealthy"
|
||||
service_info["last_error"] = health_status.get("error", "健康检查失败")
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
# 更新数据库
|
||||
await self._update_service_health_in_db(service_id, health_status)
|
||||
|
||||
except Exception as e:
|
||||
# 监控异常
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
|
||||
|
||||
# 如果重试次数过多,暂停监控
|
||||
if service_info["retry_count"] >= self.max_retry_attempts:
|
||||
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
|
||||
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
|
||||
service_info["retry_count"] = 0 # 重置重试计数
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
|
||||
|
||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
|
||||
"""更新数据库中的服务健康状态"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
else:
|
||||
mcp_config.error_message = None
|
||||
|
||||
self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
|
||||
self.db.rollback()
|
||||
|
||||
async def _management_loop(self):
|
||||
"""管理循环 - 处理服务清理等任务"""
|
||||
try:
|
||||
while self._running:
|
||||
# 清理失效的服务
|
||||
await self._cleanup_failed_services()
|
||||
|
||||
# 等待下次循环
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"管理循环异常: {e}")
|
||||
|
||||
async def _cleanup_failed_services(self):
|
||||
"""清理长期失效的服务"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
cleanup_threshold = 24 * 60 * 60 # 24小时
|
||||
|
||||
services_to_cleanup = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
# 检查服务是否长期失效
|
||||
if (service_info["status"] in ["error", "unhealthy"] and
|
||||
current_time - service_info["last_health_check"] > cleanup_threshold):
|
||||
|
||||
services_to_cleanup.append(service_id)
|
||||
|
||||
for service_id in services_to_cleanup:
|
||||
logger.warning(f"清理长期失效的服务: {service_id}")
|
||||
|
||||
# 停止监控但不删除数据库记录
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 标记为禁用
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if tool_config:
|
||||
tool_config.is_enabled = False
|
||||
self.db.commit()
|
||||
|
||||
# 从内存移除
|
||||
del self._services[service_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"total_services": len(self._services),
|
||||
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
|
||||
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
|
||||
"monitoring_tasks": len(self._monitoring_tasks),
|
||||
"connection_pool_status": self.connection_pool.get_pool_status()
|
||||
}
|
||||
def get_tool_manager(self) -> MCPToolManager:
|
||||
"""获取工具管理器实例"""
|
||||
return self.tool_manager
|
||||
@@ -64,6 +64,11 @@ def validate_model_exists_and_active(
|
||||
) -> tuple[str, bool]:
|
||||
"""Validate that a model exists and is active.
|
||||
|
||||
This function performs tenant-aware model validation with detailed error messages:
|
||||
- If model doesn't exist at all: "Model not found"
|
||||
- If model exists but belongs to different tenant: "Model belongs to different tenant" with details
|
||||
- If model exists and accessible but inactive: "Model is inactive"
|
||||
|
||||
Args:
|
||||
model_id: Model UUID to validate
|
||||
model_type: Type of model ("llm", "embedding", "rerank")
|
||||
@@ -76,7 +81,7 @@ def validate_model_exists_and_active(
|
||||
Tuple of (model_name, is_active)
|
||||
|
||||
Raises:
|
||||
ModelNotFoundError: If model does not exist
|
||||
ModelNotFoundError: If model does not exist or belongs to different tenant
|
||||
ModelInactiveError: If model exists but is inactive
|
||||
"""
|
||||
from app.repositories.model_repository import ModelConfigRepository
|
||||
@@ -84,21 +89,48 @@ def validate_model_exists_and_active(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# First check if model exists at all (without tenant filtering)
|
||||
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
|
||||
|
||||
# Then check with tenant filtering
|
||||
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not model:
|
||||
logger.warning(
|
||||
"Model not found",
|
||||
extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
|
||||
)
|
||||
raise ModelNotFoundError(
|
||||
model_id=model_id,
|
||||
model_type=model_type,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
message=f"{model_type.title()} model {model_id} not found"
|
||||
)
|
||||
if model_without_tenant:
|
||||
# Model exists but belongs to different tenant
|
||||
logger.warning(
|
||||
"Model belongs to different tenant",
|
||||
extra={
|
||||
"model_id": str(model_id),
|
||||
"model_type": model_type,
|
||||
"model_name": model_without_tenant.name,
|
||||
"model_tenant_id": str(model_without_tenant.tenant_id),
|
||||
"requested_tenant_id": str(tenant_id),
|
||||
"is_public": model_without_tenant.is_public,
|
||||
"elapsed_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
raise ModelNotFoundError(
|
||||
model_id=model_id,
|
||||
model_type=model_type,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
message=f"{model_type.title()} model {model_id} ({model_without_tenant.name}) belongs to a different tenant (model tenant: {model_without_tenant.tenant_id}, workspace tenant: {tenant_id}). The model is not public and cannot be accessed from this workspace."
|
||||
)
|
||||
else:
|
||||
# Model doesn't exist at all
|
||||
logger.warning(
|
||||
"Model not found",
|
||||
extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
|
||||
)
|
||||
raise ModelNotFoundError(
|
||||
model_id=model_id,
|
||||
model_type=model_type,
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
message=f"{model_type.title()} model {model_id} not found"
|
||||
)
|
||||
|
||||
if not model.is_active:
|
||||
logger.warning(
|
||||
|
||||
@@ -22,7 +22,7 @@ class AssignmentItem(BaseModel):
|
||||
)
|
||||
|
||||
value: Any = Field(
|
||||
...,
|
||||
default=None,
|
||||
description="Value(s) to assign to the variable(s)",
|
||||
)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: bool
|
||||
looping: Annotated[bool, lambda x, y: x and y]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
@@ -534,7 +534,7 @@ class BaseNode(ABC):
|
||||
return edge
|
||||
return None
|
||||
|
||||
def _render_template(self, template: str, state: WorkflowState | None) -> str:
|
||||
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str:
|
||||
"""渲染模板
|
||||
|
||||
支持的变量命名空间:
|
||||
@@ -568,7 +568,8 @@ class BaseNode(ABC):
|
||||
template=template,
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
struct=struct
|
||||
)
|
||||
|
||||
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
@@ -27,6 +27,16 @@ class CycleVariable(BaseNodeConfig):
|
||||
description="Initial or current value of the loop variable"
|
||||
)
|
||||
|
||||
@field_validator("input_type", mode="before")
|
||||
@classmethod
|
||||
def lower_input_type(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return ValueInputType(v.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid input_type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
|
||||
operator: ComparisonOperator = Field(
|
||||
@@ -45,10 +55,20 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
...,
|
||||
default=ValueInputType.CONSTANT,
|
||||
description="Input type of the loop variable"
|
||||
)
|
||||
|
||||
@field_validator("input_type", mode="before")
|
||||
@classmethod
|
||||
def lower_input_type(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return ValueInputType(v.lower())
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid input_type: {v}")
|
||||
return v
|
||||
|
||||
|
||||
class ConditionsConfig(BaseModel):
|
||||
"""Configuration for loop condition evaluation"""
|
||||
|
||||
@@ -37,7 +37,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
output = self._render_template(output_template, state, struct=False)
|
||||
else:
|
||||
output = "工作流已完成"
|
||||
|
||||
|
||||
@@ -93,5 +93,5 @@ class HttpErrorHandle(StrEnum):
|
||||
|
||||
|
||||
class ValueInputType(StrEnum):
|
||||
VARIABLE = "Variable"
|
||||
CONSTANT = "Constant"
|
||||
VARIABLE = "variable"
|
||||
CONSTANT = "constant"
|
||||
|
||||
@@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel):
|
||||
content_type = info.data.get("content_type")
|
||||
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
|
||||
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
|
||||
elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||
raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object")
|
||||
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is JSON, data must be of type str")
|
||||
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||
raise ValueError("When content_type is x-www-form-urlencoded, data must be an object(dict)")
|
||||
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
|
||||
return v
|
||||
|
||||
@@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode):
|
||||
return {}
|
||||
case HttpContentType.JSON:
|
||||
content["json"] = json.loads(self._render_template(
|
||||
json.dumps(self.typed_config.body.data), state
|
||||
self.typed_config.body.data, state
|
||||
))
|
||||
case HttpContentType.FROM_DATA:
|
||||
data = {}
|
||||
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
|
||||
retries -= 1
|
||||
if retries > 0:
|
||||
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
|
||||
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"HTTP request node exception: {e}")
|
||||
else:
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning error response"
|
||||
)
|
||||
return HttpRequestNodeOutput(
|
||||
body="",
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
).model_dump()
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"Node {self.node_id}: HTTP request failed, returning default result"
|
||||
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
|
||||
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
|
||||
)
|
||||
return "ERROR"
|
||||
raise RuntimeError("http request failed")
|
||||
|
||||
@@ -23,10 +23,20 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
input_type: ValueInputType = Field(
|
||||
...,
|
||||
default=ValueInputType.CONSTANT,
|
||||
description="Value input type for comparison"
|
||||
)
|
||||
|
||||
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
    """Coerce a string ``input_type`` to a ``ValueInputType`` member, ignoring case.

    Runs before field validation. Non-string values are returned untouched.

    Raises:
        ValueError: If the lower-cased string is not a ``ValueInputType`` value.
    """
    # Anything that is not a string is passed through as-is.
    if not isinstance(v, str):
        return v

    normalized = v.lower()
    try:
        member = ValueInputType(normalized)
    except ValueError:
        # Report the caller's original spelling, not the lower-cased one.
        raise ValueError(f"Invalid input_type: {v}")
    return member
|
||||
|
||||
|
||||
class ConditionBranchConfig(BaseModel):
|
||||
"""Configuration for a conditional branch"""
|
||||
|
||||
@@ -71,7 +71,10 @@ class IfElseNode(BaseNode):
|
||||
for expression in case_branch.expressions:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||
left_value = self.get_variable(left_string, state)
|
||||
try:
|
||||
left_value = self.get_variable(left_string, state)
|
||||
except KeyError:
|
||||
left_value = None
|
||||
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||
self.get_variable_pool(state),
|
||||
expression.left,
|
||||
|
||||
@@ -203,15 +203,20 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
# Deduplicate hybrid retrieval results
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
# TODO:其他重排序方式支持
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
logger.info(
|
||||
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
|
||||
)
|
||||
return [chunk.model_dump() for chunk in final_rs]
|
||||
return [chunk.page_content for chunk in final_rs]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""LLM 节点配置"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
|
||||
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
|
||||
|
||||
class MessageConfig(BaseModel):
|
||||
"""消息配置"""
|
||||
|
||||
|
||||
role: str = Field(
|
||||
...,
|
||||
description="消息角色:system, user, assistant"
|
||||
)
|
||||
|
||||
|
||||
content: str = Field(
|
||||
...,
|
||||
description="消息内容,支持模板变量,如:{{ sys.message }}"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("role")
|
||||
@classmethod
|
||||
def validate_role(cls, v: str) -> str:
|
||||
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
1. 简单模式:使用 prompt 字段
|
||||
2. 消息模式:使用 messages 字段(推荐)
|
||||
"""
|
||||
|
||||
|
||||
model_id: str = Field(
|
||||
...,
|
||||
description="模型配置 ID"
|
||||
)
|
||||
|
||||
|
||||
context: Any = Field(
|
||||
default="",
|
||||
description="上下文"
|
||||
)
|
||||
|
||||
# 简单模式
|
||||
prompt: str | None = Field(
|
||||
default=None,
|
||||
description="提示词模板(简单模式),支持变量引用"
|
||||
)
|
||||
|
||||
|
||||
# 消息模式(推荐)
|
||||
messages: list[MessageConfig] | None = Field(
|
||||
default=None,
|
||||
description="消息列表(消息模式),支持多轮对话"
|
||||
)
|
||||
|
||||
|
||||
# 模型参数
|
||||
temperature: float | None = Field(
|
||||
default=0.7,
|
||||
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
le=2.0,
|
||||
description="温度参数,控制输出的随机性"
|
||||
)
|
||||
|
||||
|
||||
max_tokens: int | None = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=32000,
|
||||
description="最大生成 token 数"
|
||||
)
|
||||
|
||||
|
||||
top_p: float | None = Field(
|
||||
default=None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Top-p 采样参数"
|
||||
)
|
||||
|
||||
|
||||
frequency_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="频率惩罚"
|
||||
)
|
||||
|
||||
|
||||
presence_penalty: float | None = Field(
|
||||
default=None,
|
||||
ge=-2.0,
|
||||
le=2.0,
|
||||
description="存在惩罚"
|
||||
)
|
||||
|
||||
|
||||
# 输出变量定义
|
||||
output_variables: list[VariableDefinition] = Field(
|
||||
default_factory=lambda: [
|
||||
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
],
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v, info):
|
||||
"""验证输入模式:prompt 和 messages 至少有一个"""
|
||||
# 这个验证在 model_validator 中更合适
|
||||
return v
|
||||
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
|
||||
@@ -5,15 +5,17 @@ LLM 节点实现
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
@@ -63,8 +65,16 @@ class LLMNode(BaseNode):
|
||||
- user/human: 用户消息(HumanMessage)
|
||||
- ai/assistant: AI 消息(AIMessage)
|
||||
"""
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
    """Initialize the node and build its typed configuration.

    Args:
        node_config: Raw node configuration, forwarded to the base class.
        workflow_config: Raw workflow configuration, forwarded to the base class.
    """
    super().__init__(node_config, workflow_config)
    # Parse the node's config mapping into an LLMNodeConfig.
    raw_config = self.config
    self.typed_config = LLMNodeConfig(**raw_config)
|
||||
|
||||
def _render_context(self, message, state):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
|
||||
return re.sub(r"{{context}}", context, message)
|
||||
|
||||
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
|
||||
"""准备 LLM 实例(公共逻辑)
|
||||
|
||||
Args:
|
||||
@@ -76,15 +86,16 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 1. 处理消息格式(优先使用 messages)
|
||||
messages_config = self.config.get("messages")
|
||||
|
||||
|
||||
if messages_config:
|
||||
# 使用 LangChain 消息格式
|
||||
messages = []
|
||||
for msg_config in messages_config:
|
||||
role = msg_config.get("role", "user").lower()
|
||||
content_template = msg_config.get("content", "")
|
||||
content_template = self._render_context(content_template, state)
|
||||
content = self._render_template(content_template, state)
|
||||
|
||||
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append(SystemMessage(content=content))
|
||||
@@ -95,7 +106,7 @@ class LLMNode(BaseNode):
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
|
||||
prompt_or_messages = messages
|
||||
else:
|
||||
# 使用简单的 prompt 格式(向后兼容)
|
||||
@@ -106,17 +117,17 @@ class LLMNode(BaseNode):
|
||||
model_id = self.config.get("model_id")
|
||||
if not model_id:
|
||||
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
|
||||
|
||||
|
||||
# 3. 在 with 块内完成所有数据库操作和数据提取
|
||||
with get_db_context() as db:
|
||||
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
|
||||
if not config:
|
||||
|
||||
if not config:
|
||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
@@ -124,26 +135,26 @@ class LLMNode(BaseNode):
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
extra_params = {"streaming": stream} if stream else {}
|
||||
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
|
||||
return llm, prompt_or_messages
|
||||
|
||||
|
||||
async def execute(self, state: WorkflowState) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
@@ -153,10 +164,10 @@ class LLMNode(BaseNode):
|
||||
Returns:
|
||||
LLM 响应消息
|
||||
"""
|
||||
llm, prompt_or_messages = self._prepare_llm(state,True)
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(prompt_or_messages)
|
||||
# 提取内容
|
||||
@@ -164,16 +175,16 @@ class LLMNode(BaseNode):
|
||||
content = response.content
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return response if isinstance(response, AIMessage) else AIMessage(content=content)
|
||||
|
||||
|
||||
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
_, prompt_or_messages = self._prepare_llm(state)
|
||||
|
||||
|
||||
return {
|
||||
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
|
||||
"messages": [
|
||||
@@ -186,13 +197,13 @@ class LLMNode(BaseNode):
|
||||
"max_tokens": self.config.get("max_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _extract_output(self, business_result: Any) -> str:
    """Extract the text content from an LLM result.

    Args:
        business_result: Value produced by the node's execution.

    Returns:
        ``business_result.content`` when the result is an ``AIMessage``;
        otherwise ``str(business_result)``.
    """
    # Non-message results are simply stringified.
    if not isinstance(business_result, AIMessage):
        return str(business_result)
    return business_result.content
|
||||
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从 AIMessage 中提取 token 使用情况"""
|
||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||
@@ -204,7 +215,7 @@ class LLMNode(BaseNode):
|
||||
"total_tokens": usage.get('total_tokens', 0)
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
@@ -215,26 +226,26 @@ class LLMNode(BaseNode):
|
||||
文本片段(chunk)或完成标记
|
||||
"""
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
|
||||
llm, prompt_or_messages = self._prepare_llm(state, True)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
|
||||
# 检查是否有注入的 End 节点前缀配置
|
||||
writer = get_stream_writer()
|
||||
end_prefix = getattr(self, '_end_node_prefix', None)
|
||||
|
||||
|
||||
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
|
||||
if end_prefix:
|
||||
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
|
||||
|
||||
|
||||
if end_prefix:
|
||||
# 渲染前缀(可能包含其他变量)
|
||||
try:
|
||||
rendered_prefix = self._render_template(end_prefix, state)
|
||||
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
|
||||
|
||||
|
||||
# 提前发送 End 节点的前缀(使用 "message" 类型)
|
||||
writer({
|
||||
"type": "message", # End 相关的内容都是 message 类型
|
||||
@@ -246,12 +257,12 @@ class LLMNode(BaseNode):
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
|
||||
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
async for chunk in llm.astream(prompt_or_messages):
|
||||
# 提取内容
|
||||
@@ -259,18 +270,18 @@ class LLMNode(BaseNode):
|
||||
content = chunk.content
|
||||
else:
|
||||
content = str(chunk)
|
||||
|
||||
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
last_chunk = chunk
|
||||
chunk_count += 1
|
||||
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield content
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
if isinstance(last_chunk, AIMessage):
|
||||
final_message = AIMessage(
|
||||
@@ -279,6 +290,6 @@ class LLMNode(BaseNode):
|
||||
)
|
||||
else:
|
||||
final_message = AIMessage(content=full_response)
|
||||
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
|
||||
@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
search_switch=self.typed_config.search_switch,
|
||||
history=[],
|
||||
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
|
||||
|
||||
return await MemoryAgentService().write_memory(
|
||||
group_id=end_user_id,
|
||||
message=self.typed_config.message,
|
||||
message=self._render_template(self.typed_config.message, state),
|
||||
config_id=self.typed_config.config_id,
|
||||
db=db,
|
||||
storage_type="neo4j",
|
||||
|
||||
@@ -387,6 +387,11 @@ class ArrayComparisonOperator(ConditionBase):
|
||||
return self.right_value not in self.left_value
|
||||
|
||||
|
||||
class NoneObjectComparisonOperator(ConditionBase):
|
||||
def __getattr__(self, name):
|
||||
return lambda *args, **kwargs: False
|
||||
|
||||
|
||||
CompareOperatorInstance = Union[
|
||||
StringComparisonOperator,
|
||||
NumberComparisonOperator,
|
||||
@@ -405,6 +410,7 @@ class ConditionExpressionResolver:
|
||||
float: NumberComparisonOperator,
|
||||
list: ArrayComparisonOperator,
|
||||
dict: ObjectComparisonOperator,
|
||||
type(None): NoneObjectComparisonOperator
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||
if category_count > 0:
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
try:
|
||||
llm = self._get_llm_instance()
|
||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
|
||||
return f"CASE{category_names.index(category) + 1}"
|
||||
return {
|
||||
"class_name": category,
|
||||
"output": f"CASE{category_names.index(category) + 1}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
# 异常时返回默认分支,保证工作流容错性
|
||||
if category_count > 0:
|
||||
return DEFAULT_EMPTY_QUESTION_CASE
|
||||
return "unknown"
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
||||
"""工具节点配置"""
|
||||
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||
if workflow_id:
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
if workspace_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
# return {"error": "缺少租户ID"}
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
print(self.typed_config.tool_id)
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
print(result)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"data": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
|
||||
def __init__(self, strict: bool = True):
|
||||
"""初始化渲染器
|
||||
|
||||
@@ -25,13 +25,13 @@ class TemplateRenderer:
|
||||
undefined=StrictUndefined if strict else Undefined,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
self,
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
@@ -69,40 +69,40 @@ class TemplateRenderer:
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
|
||||
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
|
||||
|
||||
def validate(self, template: str) -> list[str]:
|
||||
"""验证模板语法
|
||||
|
||||
@@ -121,14 +121,14 @@ class TemplateRenderer:
|
||||
['模板语法错误: ...']
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
try:
|
||||
self.env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
errors.append(f"模板语法错误: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"模板验证失败: {e}")
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
@@ -137,14 +137,16 @@ _default_renderer = TemplateRenderer(strict=True)
|
||||
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
template: str,
|
||||
variables: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None,
|
||||
struct: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
Args:
|
||||
struct: 渲染模式
|
||||
template: 模板字符串
|
||||
variables: 用户变量
|
||||
node_outputs: 节点输出
|
||||
@@ -162,7 +164,8 @@ def render_template(
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
"""
|
||||
return _default_renderer.render(template, variables, node_outputs, system_vars)
|
||||
renderer = TemplateRenderer(strict=struct)
|
||||
return renderer.render(template, variables, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
|
||||
@@ -87,10 +87,11 @@ class WorkflowValidator:
|
||||
return graphs
|
||||
|
||||
@classmethod
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
|
||||
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置
|
||||
|
||||
Args:
|
||||
publish: 发布验证标识
|
||||
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
|
||||
|
||||
Returns:
|
||||
@@ -114,7 +115,7 @@ class WorkflowValidator:
|
||||
|
||||
graphs = cls.get_subgraph(workflow_config)
|
||||
logger.info(graphs)
|
||||
for graph in graphs:
|
||||
for index, graph in enumerate(graphs):
|
||||
nodes = graph.get("nodes", [])
|
||||
edges = graph.get("edges", [])
|
||||
variables = graph.get("variables", [])
|
||||
@@ -125,10 +126,11 @@ class WorkflowValidator:
|
||||
elif len(start_nodes) > 1:
|
||||
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)} 个")
|
||||
|
||||
# 2. 验证 end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
if index == len(graphs) - 1:
|
||||
# 2. 验证 主图end 节点(至少一个)
|
||||
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
|
||||
if len(end_nodes) == 0:
|
||||
errors.append("工作流必须至少有一个 end 节点")
|
||||
|
||||
# 3. 验证节点 ID 唯一性
|
||||
node_ids = [n.get("id") for n in nodes]
|
||||
@@ -159,15 +161,17 @@ class WorkflowValidator:
|
||||
elif target not in node_id_set:
|
||||
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
|
||||
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
if publish:
|
||||
# 仅在发布时验证所有节点可达
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
unreachable = node_id_set - reachable
|
||||
if unreachable:
|
||||
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
|
||||
|
||||
# 7. 检测循环依赖(非 loop 节点)
|
||||
if not errors: # 只有在前面验证通过时才检查循环
|
||||
@@ -288,7 +292,7 @@ class WorkflowValidator:
|
||||
(is_valid, errors): 是否有效和错误列表
|
||||
"""
|
||||
# 先执行基础验证
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config)
|
||||
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
|
||||
|
||||
if not is_valid:
|
||||
return False, errors
|
||||
|
||||
@@ -6,6 +6,7 @@ from .document_model import Document
|
||||
from .file_model import File
|
||||
from .generic_file_model import GenericFile
|
||||
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
|
||||
from .memory_short_model import ShortTermMemory, LongTermMemory
|
||||
from .knowledgeshare_model import KnowledgeShare
|
||||
from .app_model import App
|
||||
from .agent_app_config_model import AgentConfig
|
||||
@@ -25,6 +26,7 @@ from .tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
from .memory_perceptual_model import MemoryPerceptualModel
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -67,9 +69,12 @@ __all__ = [
|
||||
"BuiltinToolConfig",
|
||||
"CustomToolConfig",
|
||||
"MCPToolConfig",
|
||||
"ShortTermMemory",
|
||||
"LongTermMemory",
|
||||
"ToolExecution",
|
||||
"ToolType",
|
||||
"ToolStatus",
|
||||
"AuthType",
|
||||
"ExecutionStatus"
|
||||
"ExecutionStatus",
|
||||
"MemoryPerceptualModel"
|
||||
]
|
||||
|
||||
@@ -3,7 +3,10 @@ import uuid
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.base.type import PydanticType
|
||||
from app.db import Base
|
||||
from app.schemas import ModelParameters
|
||||
|
||||
|
||||
class AgentConfig(Base):
|
||||
@@ -17,14 +20,17 @@ class AgentConfig(Base):
|
||||
# Agent 行为配置
|
||||
system_prompt = Column(Text, nullable=True, comment="系统提示词")
|
||||
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
|
||||
|
||||
|
||||
# 结构化配置(直接存储 JSON)
|
||||
model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||
# model_parameters = Column(JSON, nullable=True, comment="模型参数配置(temperature、max_tokens等)")
|
||||
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
|
||||
comment="模型参数配置(temperature、max_tokens等)")
|
||||
|
||||
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
|
||||
memory = Column(JSON, nullable=True, comment="记忆配置")
|
||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
|
||||
|
||||
|
||||
# 多 Agent 相关字段
|
||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
|
||||
@@ -41,4 +47,4 @@ class AgentConfig(Base):
|
||||
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
|
||||
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean, Integer, Text, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
@@ -25,56 +26,69 @@ class Conversation(Base):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
|
||||
|
||||
|
||||
# 关联信息
|
||||
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="应用ID")
|
||||
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, comment="工作空间ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID(外部系统)")
|
||||
|
||||
|
||||
# 会话信息
|
||||
title = Column(String(255), comment="会话标题")
|
||||
summary = Column(Text, comment="会话摘要")
|
||||
|
||||
|
||||
# 会话类型:True=草稿会话(使用草稿配置),False=发布会话(使用发布配置)
|
||||
is_draft = Column(Boolean, default=True, nullable=False, comment="是否为草稿会话")
|
||||
|
||||
|
||||
# 配置快照:保存创建会话时的完整配置,用于审计和问题追溯
|
||||
config_snapshot = Column(JSON, comment="配置快照(Agent配置、模型配置等)")
|
||||
|
||||
|
||||
# 统计信息
|
||||
message_count = Column(Integer, default=0, comment="消息数量")
|
||||
|
||||
|
||||
# 状态
|
||||
is_active = Column(Boolean, default=True, nullable=False, comment="是否活跃")
|
||||
|
||||
|
||||
# 时间戳
|
||||
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
|
||||
# 关联关系
|
||||
app = relationship("App", back_populates="conversations")
|
||||
workspace = relationship("Workspace")
|
||||
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class ConversationDetail(Base):
    """Descriptive details stored for a conversation.

    Each row references ``conversations.id`` and carries a theme, a summary,
    a JSON list of takeaways, the user's questions and an information score.
    """

    __tablename__ = "conversation_details"

    # Primary key, generated on the Python side with uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
    )
    # Conversation this detail row belongs to.
    conversation_id = Column(
        UUID(as_uuid=True),
        ForeignKey("conversations.id"),
    )

    # Conversation theme.
    theme = Column(String, comment="会话主题")
    # Conversation summary.
    summary = Column(String, comment="会话摘要")
    # Key takeaways of the conversation (JSON).
    takeaways = Column(JSON, comment="会话要点")
    # Questions asked by the user (JSON).
    question = Column(JSON, comment="用户问题")
    # Score describing how much information the conversation carries.
    info_score = Column(Integer, comment="会话信息量评分")
|
||||
|
||||
|
||||
class Message(Base):
    """Message table: one row per message exchanged in a conversation."""

    __tablename__ = "messages"

    # Primary key, generated on the Python side with uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
    )

    # --- Association ---
    conversation_id = Column(
        UUID(as_uuid=True),
        ForeignKey("conversations.id"),
        nullable=False,
        comment="会话ID",
    )

    # --- Message content ---
    # Role of the author: user / assistant / system.
    role = Column(
        String(20),
        nullable=False,
        comment="角色: user/assistant/system",
    )
    content = Column(Text, nullable=False, comment="消息内容")

    # --- Metadata ---
    # Named ``meta_data`` because ``metadata`` is a reserved attribute name
    # on declarative models.
    meta_data = Column(JSON, comment="消息元数据（如模型、token使用等）")

    # --- Timestamps ---
    created_at = Column(
        DateTime,
        default=datetime.datetime.now,
        comment="创建时间",
    )

    # --- Relationships ---
    conversation = relationship("Conversation", back_populates="messages")
|
||||
|
||||
@@ -25,7 +25,6 @@ class DataConfig(Base):
|
||||
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
|
||||
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
|
||||
llm = Column(String, nullable=True, comment="LLM模型配置ID")
|
||||
|
||||
# 记忆萃取引擎配置
|
||||
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")
|
||||
|
||||
40
api/app/models/forgetting_cycle_history_model.py
Normal file
40
api/app/models/forgetting_cycle_history_model.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
遗忘周期历史记录模型
|
||||
|
||||
用于存储每次遗忘周期执行的历史数据,支持趋势分析和可视化。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, Integer, String, Float, DateTime, Index
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ForgettingCycleHistory(Base):
    """History table for forgetting-cycle executions.

    One row is stored per forgetting-cycle run of an end user, recording the
    merge statistics, node counts, duration and how the run was triggered.
    """

    __tablename__ = "forgetting_cycle_history"

    # Primary key, generated on the Python side with uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        nullable=False,
        index=True,
        comment="主键ID",
    )
    # End user the cycle was executed for.
    end_user_id = Column(String(255), nullable=False, comment="终端用户ID")
    # Time of execution; defaults to the local "now" at insert time.
    execution_time = Column(
        DateTime,
        nullable=False,
        default=datetime.now,
        comment="执行时间",
    )
    # Number of node pairs merged successfully in this run.
    merged_count = Column(Integer, default=0, comment="本次成功融合的节点对数")
    # Number of node pairs whose merge failed in this run.
    failed_count = Column(Integer, default=0, comment="本次融合失败的节点对数")
    # Average activation value.
    average_activation_value = Column(Float, nullable=True, comment="平均激活值")
    # Total number of nodes.
    total_nodes = Column(Integer, default=0, comment="总节点数")
    # Nodes below the forgetting threshold (merged, failed and pending).
    low_activation_nodes = Column(
        Integer,
        default=0,
        comment="低于遗忘阈值的节点总数（包含已融合、失败和待处理的）",
    )
    # Execution duration in seconds.
    duration_seconds = Column(Float, nullable=True, comment="执行耗时（秒）")
    # How the run was triggered: manual or scheduled.
    trigger_type = Column(
        String(50),
        default="manual",
        comment="触发类型: manual/scheduled",
    )

    # Indexes supporting per-user time-ordered lookups and global
    # time-ordered lookups.
    __table_args__ = (
        Index('idx_end_user_time', 'end_user_id', 'execution_time'),
        Index('idx_execution_time', 'execution_time'),
    )

    def __repr__(self):
        """Return a short debug representation of the record."""
        fields = (
            f"id={self.id}, end_user_id={self.end_user_id}, "
            f"merged_count={self.merged_count}, execution_time={self.execution_time}"
        )
        return f"<ForgettingCycleHistory({fields})>"
|
||||
40
api/app/models/memory_perceptual_model.py
Normal file
40
api/app/models/memory_perceptual_model.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import Column, ForeignKey, Integer, DateTime, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class PerceptualType(IntEnum):
    """Integer codes identifying the kind of a perceptual memory."""

    VISION = 1
    AUDIO = 2
    TEXT = 3
    CONVERSATION = 4
|
||||
|
||||
|
||||
class FileStorageType(IntEnum):
    """Integer codes identifying where a perceptual-memory file is stored."""

    LOCAL = 1
    REMOTE = 2
|
||||
|
||||
|
||||
class MemoryPerceptualModel(Base):
    """ORM model for the ``memory_perceptual`` table.

    A row describes one perceptual memory of an end user: its type, the file
    backing it (path, name, extension, storage service), a summary and free
    form metadata.
    """

    __tablename__ = "memory_perceptual"

    # Primary key, generated on the Python side with uuid4.
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Owning end user.
    end_user_id = Column(
        UUID(as_uuid=True),
        ForeignKey("end_users.id"),
        index=True,
    )

    # Perceptual type, stored as an integer.
    perceptual_type = Column(
        Integer,
        index=True,
        nullable=False,
        comment="感知类型",
    )

    # Storage service type, stored as an integer.
    # NOTE(review): the default is 0, while FileStorageType in this module
    # only declares 1 (LOCAL) and 2 (REMOTE) - confirm 0 is intended.
    storage_service = Column(Integer, default=0, comment="存储服务类型")
    # File location and naming.
    file_path = Column(String, nullable=False, comment="文件路径")
    file_name = Column(String, nullable=False, comment="文件名称")
    file_ext = Column(String, nullable=False, comment="文件后缀名")

    # Summary text and arbitrary metadata (JSONB).
    summary = Column(String, comment="摘要")
    meta_data = Column(JSONB, comment="元信息")

    # Creation time; defaults to the local "now" at insert time.
    created_time = Column(
        DateTime,
        default=datetime.datetime.now,
        comment="创建时间",
    )
|
||||
60
api/app/models/memory_short_model.py
Normal file
60
api/app/models/memory_short_model.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
记忆模型 - 短期记忆和长期记忆表
|
||||
"""
|
||||
import uuid
|
||||
import datetime
|
||||
from sqlalchemy import Column, String, DateTime, Text, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
class ShortTermMemory(Base):
    """Short-term memory table.

    Stores temporary conversation memories that are normally kept for a
    short period only.
    """

    __tablename__ = "memory_short_term"

    # Primary key, generated on the Python side with uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
        comment="记忆ID",
    )

    # --- User ---
    end_user_id = Column(
        String(255),
        nullable=False,
        index=True,
        comment="终端用户ID",
    )

    # --- Dialogue content ---
    # User message text.
    messages = Column(Text, nullable=False, comment="用户消息内容")
    # AI reply text.
    aimessages = Column(Text, nullable=True, comment="AI回复消息内容")

    # --- Search switch state ---
    search_switch = Column(String(50), nullable=True, comment="搜索开关状态")

    # --- Retrieved content ---
    # JSON list of dicts, i.e. [{}, {}]; defaults to a fresh empty list.
    retrieved_content = Column(
        JSON,
        nullable=True,
        default=list,
        comment="检索到的相关内容，格式为[{}, {}]",
    )

    # --- Timestamps ---
    created_at = Column(
        DateTime,
        default=datetime.datetime.now,
        nullable=False,
        index=True,
        comment="创建时间",
    )

    def __repr__(self):
        """Return a short debug representation of the record."""
        fields = (
            f"id={self.id}, end_user_id={self.end_user_id}, "
            f"created_at={self.created_at}"
        )
        return f"<ShortTermMemory({fields})>"
|
||||
|
||||
|
||||
class LongTermMemory(Base):
    """Long-term memory table.

    Stores important conversation memories that are kept for the long term.
    """

    __tablename__ = "memory_long_term"

    # Primary key, generated on the Python side with uuid4.
    id = Column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
        index=True,
        comment="记忆ID",
    )

    # --- User ---
    end_user_id = Column(
        String(255),
        nullable=False,
        index=True,
        comment="终端用户ID",
    )

    # --- Retrieved content ---
    # JSON list of dicts, i.e. [{}, {}]; defaults to a fresh empty list.
    retrieved_content = Column(
        JSON,
        nullable=True,
        default=list,
        comment="检索到的相关内容，格式为[{}, {}]",
    )

    # --- Timestamps ---
    created_at = Column(
        DateTime,
        default=datetime.datetime.now,
        nullable=False,
        index=True,
        comment="创建时间",
    )

    def __repr__(self):
        """Return a short debug representation of the record."""
        fields = (
            f"id={self.id}, end_user_id={self.end_user_id}, "
            f"created_at={self.created_at}"
        )
        return f"<LongTermMemory({fields})>"
|
||||
@@ -8,15 +8,15 @@ from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text,
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.base.type import PydanticType
|
||||
from app.db import Base
|
||||
from app.schemas import ModelParameters
|
||||
|
||||
|
||||
class OrchestrationMode(StrEnum):
|
||||
"""图标类型枚举"""
|
||||
SEQUENTIAL = "sequential"
|
||||
PARALLEL = "parallel"
|
||||
CONDITIONAL = "conditional"
|
||||
"""协作模式枚举"""
|
||||
COLLABORATION = "collaboration" # 协作模式:Agent 之间可以相互 handoff
|
||||
SUPERVISOR = "supervisor" # 监督模式:由主 Agent 统一调度子 Agent
|
||||
|
||||
class AggregationStrategy(StrEnum):
|
||||
"""图标类型枚举"""
|
||||
@@ -24,27 +24,27 @@ class AggregationStrategy(StrEnum):
|
||||
VOTE = "vote"
|
||||
PRIORITY = "priority"
|
||||
|
||||
class PydanticType(TypeDecorator):
|
||||
impl = JSON
|
||||
# class PydanticType(TypeDecorator):
|
||||
# impl = JSON
|
||||
|
||||
def __init__(self, pydantic_model: type[BaseModel]):
|
||||
super().__init__()
|
||||
self.model = pydantic_model
|
||||
# def __init__(self, pydantic_model: type[BaseModel]):
|
||||
# super().__init__()
|
||||
# self.model = pydantic_model
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
# 入库:Model -> dict
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, self.model):
|
||||
return value.dict()
|
||||
return value # 已经是 dict 也放行
|
||||
# def process_bind_param(self, value, dialect):
|
||||
# # 入库:Model -> dict
|
||||
# if value is None:
|
||||
# return None
|
||||
# if isinstance(value, self.model):
|
||||
# return value.dict()
|
||||
# return value # 已经是 dict 也放行
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
# 出库:dict -> Model
|
||||
if value is None:
|
||||
return None
|
||||
# return self.model.parse_obj(value) # pydantic v1
|
||||
return self.model.model_validate(value) # pydantic v2
|
||||
# def process_result_value(self, value, dialect):
|
||||
# # 出库:dict -> Model
|
||||
# if value is None:
|
||||
# return None
|
||||
# # return self.model.parse_obj(value) # pydantic v1
|
||||
# return self.model.model_validate(value) # pydantic v2
|
||||
|
||||
class MultiAgentConfig(Base):
|
||||
"""多 Agent 配置表"""
|
||||
@@ -66,8 +66,8 @@ class MultiAgentConfig(Base):
|
||||
orchestration_mode = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
default="conditional",
|
||||
comment="协作模式: sequential|parallel|conditional|loop"
|
||||
default="collaboration",
|
||||
comment="协作模式: collaboration(协作)| supervisor(监督)"
|
||||
)
|
||||
|
||||
# 子 Agent 列表
|
||||
|
||||
317
api/app/repositories/conversation_repository.py
Normal file
317
api/app/repositories/conversation_repository.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models import Conversation, Message
|
||||
from app.models.conversation_model import ConversationDetail
|
||||
|
||||
logger = get_db_logger()
|
||||
|
||||
|
||||
class ConversationRepository:
    """Repository for Conversation entity, encapsulating CRUD operations.

    Methods only stage changes on the injected session (``add`` or attribute
    assignment); nothing in this class commits or flushes, so transaction
    control stays with the caller.
    """

    def __init__(self, db: Session):
        """Bind the repository to a SQLAlchemy session."""
        self.db = db

    def create_conversation(
            self,
            app_id: uuid.UUID,
            workspace_id: uuid.UUID,
            user_id: Optional[str] = None,
            title: Optional[str] = None,
            is_draft: bool = False,
            config_snapshot: Optional[dict] = None
    ) -> Conversation:
        """
        Create a new conversation record and add it to the session.

        Args:
            app_id: Application ID the conversation belongs to.
            workspace_id: Workspace ID where the conversation is created.
            user_id: Optional user ID associated with the conversation.
            title: Optional conversation title. A missing or empty title
                falls back to "New Conversation".
            is_draft: Whether the conversation is a draft.
            config_snapshot: Optional configuration snapshot.

        Returns:
            Conversation: Newly created (not yet committed) instance.
        """
        conversation = Conversation(
            app_id=app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            title=title or "New Conversation",
            is_draft=is_draft,
            config_snapshot=config_snapshot
        )
        self.db.add(conversation)
        return conversation

    def get_conversation_by_conversation_id(
            self,
            conversation_id: uuid.UUID,
            workspace_id: Optional[uuid.UUID] = None
    ) -> Conversation:
        """
        Retrieve a conversation by its ID, optionally filtered by workspace.

        Args:
            conversation_id: The UUID of the conversation.
            workspace_id: Optional workspace UUID to filter the conversation.

        Raises:
            ResourceNotFoundException: If conversation does not exist.

        Returns:
            Conversation: The matching Conversation instance.
        """
        logger.info(f"Fetching conversation: {conversation_id}")

        stmt = select(Conversation).where(Conversation.id == conversation_id)

        if workspace_id:
            stmt = stmt.where(Conversation.workspace_id == workspace_id)

        conversation = self.db.scalars(stmt).first()

        if not conversation:
            logger.warning(f"Conversation not found: {conversation_id}")
            raise ResourceNotFoundException("Conversation", str(conversation_id))

        logger.info(f"Conversation fetched successfully: {conversation_id}")
        return conversation

    def get_conversation_by_user_id(
            self,
            user_id: uuid.UUID,
            workspace_id: Optional[uuid.UUID] = None,
            limit: int = 10,
            is_activate: bool = True
    ) -> list[Conversation]:
        """
        Retrieve recent conversations for a specific user.

        Conversations are matched on the string form of ``user_id``,
        optionally scoped to a workspace, ordered by most recently updated
        and limited to ``limit`` rows.

        Args:
            user_id (uuid.UUID): Unique identifier of the user.
            workspace_id (uuid.UUID, optional): Workspace scope for the query.
                If provided, only conversations under this workspace are
                returned.
            limit (int): Maximum number of conversations to return.
                Defaults to 10.
            is_activate (bool): Required value of the conversation's
                ``is_active`` flag.

        Returns:
            list[Conversation]: Conversations ordered by last updated time
                (descending).
        """
        logger.info(f"Fetching conversation by user_id: {user_id}")

        stmt = select(Conversation).where(
            Conversation.user_id == str(user_id),
            Conversation.is_active.is_(is_activate)
        )

        if workspace_id:
            stmt = stmt.where(Conversation.workspace_id == workspace_id)

        stmt = stmt.order_by(desc(Conversation.updated_at))
        stmt = stmt.limit(limit)

        conversations = list(self.db.scalars(stmt).all())
        logger.info(
            "Conversation fetched successfully",
            extra={
                "user_id": str(user_id),
                "workspace_id": str(workspace_id),
            }
        )
        return conversations

    def list_conversations(
            self,
            app_id: uuid.UUID,
            workspace_id: uuid.UUID,
            user_id: Optional[str] = None,
            is_draft: Optional[bool] = None,
            page: int = 1,
            pagesize: int = 20
    ) -> tuple[list[Conversation], int]:
        """
        List active conversations with optional filters and pagination.

        Args:
            app_id: Application ID filter.
            workspace_id: Workspace ID filter.
            user_id: Optional user ID filter.
            is_draft: Optional draft status filter.
            page: Page number (1-based). Values below 1 are treated as 1.
            pagesize: Number of items per page. Negative values are
                treated as 0.

        Returns:
            Tuple[List[Conversation], int]: Conversations of the requested
                page and the total number of matching rows.
        """
        stmt = select(Conversation).where(
            Conversation.app_id == app_id,
            Conversation.workspace_id == workspace_id,
            Conversation.is_active.is_(True)
        )

        if user_id:
            stmt = stmt.where(Conversation.user_id == str(user_id))

        if is_draft is not None:
            stmt = stmt.where(Conversation.is_draft == is_draft)

        # Count over the filtered statement before ordering/pagination.
        total = int(self.db.execute(
            select(func.count()).select_from(stmt.subquery())
        ).scalar_one())

        # A page below 1 or a negative page size would yield a negative
        # OFFSET/LIMIT, which the database rejects; clamp them instead.
        page = max(page, 1)
        pagesize = max(pagesize, 0)

        stmt = stmt.order_by(desc(Conversation.updated_at))
        stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)

        conversations = list(self.db.scalars(stmt).all())

        logger.info(
            "Listed conversations successfully",
            extra={
                "app_id": str(app_id),
                "workspace_id": str(workspace_id),
                "returned": len(conversations),
                "total": total
            }
        )
        return conversations, total

    def soft_delete_conversation_by_conversation_id(
            self,
            conversation_id: uuid.UUID,
            workspace_id: uuid.UUID,
    ):
        """
        Soft delete a conversation by setting is_active to False.

        Args:
            conversation_id: The UUID of the conversation.
            workspace_id: Workspace ID for verification.

        Raises:
            ResourceNotFoundException: If the conversation does not exist
                in the given workspace.
        """
        conversation = self.get_conversation_by_conversation_id(
            conversation_id,
            workspace_id
        )
        conversation.is_active = False

    def get_conversation_detail(
            self,
            conversation_id: uuid.UUID
    ) -> ConversationDetail | None:
        """
        Retrieve the detail of a conversation by its ID.

        Args:
            conversation_id (UUID): The unique identifier of the conversation.

        Returns:
            ConversationDetail or None: The first matching detail row, or
                None when there is none.

        Notes:
            - This method queries the database but does not modify it.
            - The caller is responsible for handling a None result.
        """
        stmt = select(ConversationDetail).where(
            ConversationDetail.conversation_id == conversation_id
        )
        detail = self.db.scalars(stmt).first()
        return detail

    def add_conversation_detail(
            self,
            conversation_detail: ConversationDetail,
    ):
        """
        Add a new conversation detail record to the database session.

        Args:
            conversation_detail (ConversationDetail): The ORM object
                representing the conversation detail to add.

        Returns:
            ConversationDetail: The same object added to the session.

        Notes:
            - This method only adds the object to the current session.
            - It does not commit the transaction; commit/rollback is handled
              by the caller.
        """
        self.db.add(conversation_detail)
        return conversation_detail
|
||||
|
||||
|
||||
class MessageRepository:
    """Repository for Message entity, encapsulating CRUD operations."""

    def __init__(self, db: Session):
        """Bind the repository to a SQLAlchemy session."""
        self.db = db

    def add_message(self, message: Message) -> Message:
        """
        Stage a message on the current session.

        Args:
            message (Message): The Message ORM object to be added.

        Returns:
            Message: The same object that was passed in.

        Notes:
            - The object is only added to the session; nothing is committed.
            - Commit/rollback is left to the caller.
        """
        self.db.add(message)
        return message

    def get_message_by_conversation_id(
            self,
            conversation_id: uuid.UUID,
            limit: Optional[int] = None
    ) -> list[Message]:
        """
        Retrieve the messages of a conversation ordered by creation time.

        Args:
            conversation_id: The UUID of the conversation.
            limit: Optional limit on the number of messages returned.

        Returns:
            List[Message]: List of Message instances, oldest first.
        """
        query = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
        )

        # A falsy limit (None, and also 0) applies no LIMIT clause.
        if limit:
            query = query.limit(limit)

        rows = list(self.db.scalars(query).all())

        log_context = {
            "conversation_id": str(conversation_id),
            "returned": len(rows)
        }
        logger.info("Fetched messages successfully", extra=log_context)
        return rows
|
||||
@@ -327,7 +327,7 @@ class DataConfigRepository:
|
||||
# 更新字段映射
|
||||
field_mapping = {
|
||||
# 模型选择
|
||||
"llm_id": "llm",
|
||||
"llm_id": "llm_id",
|
||||
"embedding_id": "embedding_id",
|
||||
"rerank_id": "rerank_id",
|
||||
# 记忆萃取引擎
|
||||
|
||||
105
api/app/repositories/forgetting_cycle_history_repository.py
Normal file
105
api/app/repositories/forgetting_cycle_history_repository.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
遗忘周期历史记录仓储
|
||||
|
||||
提供遗忘周期历史记录的数据访问操作。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, and_
|
||||
|
||||
from app.models.forgetting_cycle_history_model import ForgettingCycleHistory
|
||||
|
||||
|
||||
class ForgettingCycleHistoryRepository:
    """Data-access helper for forgetting-cycle history records.

    The repository is stateless: every method receives the session to use.
    """

    def create(
        self,
        db: Session,
        end_user_id: str,
        execution_time: datetime,
        merged_count: int,
        failed_count: int,
        average_activation_value: Optional[float],
        total_nodes: int,
        low_activation_nodes: int,
        duration_seconds: float,
        trigger_type: str = "manual"
    ) -> ForgettingCycleHistory:
        """
        Create a history record, commit it and return the refreshed row.

        Args:
            db: Database session.
            end_user_id: End user ID.
            execution_time: Time the cycle was executed.
            merged_count: Number of merged nodes.
            failed_count: Number of failed nodes.
            average_activation_value: Average activation value.
            total_nodes: Total number of nodes.
            low_activation_nodes: Number of low-activation nodes.
            duration_seconds: Execution duration.
            trigger_type: Trigger type.

        Returns:
            ForgettingCycleHistory: The created history record.
        """
        record = ForgettingCycleHistory(
            end_user_id=end_user_id,
            execution_time=execution_time,
            merged_count=merged_count,
            failed_count=failed_count,
            average_activation_value=average_activation_value,
            total_nodes=total_nodes,
            low_activation_nodes=low_activation_nodes,
            duration_seconds=duration_seconds,
            trigger_type=trigger_type
        )

        # Persist immediately and reload so defaults are populated.
        db.add(record)
        db.commit()
        db.refresh(record)

        return record

    def get_recent_by_end_user(
        self,
        db: Session,
        end_user_id: str
    ) -> List[ForgettingCycleHistory]:
        """
        Return every history record of an end user, newest first.

        Note: no limit is applied; callers have to handle grouping by date
        and truncation themselves.

        Args:
            db: Database session.
            end_user_id: End user ID.

        Returns:
            List[ForgettingCycleHistory]: Records ordered by execution time,
                descending.
        """
        user_records = db.query(ForgettingCycleHistory).filter(
            ForgettingCycleHistory.end_user_id == end_user_id
        )
        newest_first = user_records.order_by(
            ForgettingCycleHistory.execution_time.desc()
        )
        return newest_first.all()

    def get_latest_by_end_user(
        self,
        db: Session,
        end_user_id: str
    ) -> Optional[ForgettingCycleHistory]:
        """
        Return the most recent history record of an end user.

        Args:
            db: Database session.
            end_user_id: End user ID.

        Returns:
            Optional[ForgettingCycleHistory]: The latest record, or None
                when the user has none.
        """
        user_records = db.query(ForgettingCycleHistory).filter(
            ForgettingCycleHistory.end_user_id == end_user_id
        )
        newest_first = user_records.order_by(
            desc(ForgettingCycleHistory.execution_time)
        )
        return newest_first.first()
|
||||
156
api/app/repositories/memory_perceptual_repository.py
Normal file
156
api/app/repositories/memory_perceptual_repository.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
|
||||
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
|
||||
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class MemoryPerceptualRepository:
    """Data Access Layer for perceptual memory.

    Every method logs its failure and re-raises; none of them commits.
    """

    def __init__(self, db: Session):
        """Bind the repository to a SQLAlchemy session."""
        self.db = db

    # ==================== Create and update ====================
    def create_perceptual_memory(
            self,
            end_user_id: uuid.UUID,
            perceptual_type: PerceptualType,
            file_path: str,
            file_name: str,
            file_ext: str,
            summary: Optional[str] = None,
            meta_data: Optional[dict] = None,
            storage_service: FileStorageType = FileStorageType.LOCAL
    ) -> MemoryPerceptualModel:
        """Create a perceptual memory row and flush it to the database.

        The row is added and flushed (so its ``id`` is available) but not
        committed; ``created_time`` is set explicitly to the current time.
        """
        db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, "
                        f"type={perceptual_type}, file={file_name}")

        try:
            record = MemoryPerceptualModel(
                end_user_id=end_user_id,
                perceptual_type=perceptual_type,
                storage_service=storage_service,
                file_path=file_path,
                file_name=file_name,
                file_ext=file_ext,
                summary=summary,
                meta_data=meta_data,
                created_time=datetime.now()
            )

            self.db.add(record)
            self.db.flush()

            db_logger.info(f"Perceptual memory created successfully: id={record.id}, file={file_name}")
            return record

        except Exception as e:
            db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(e)}")
            raise

    # ==================== Query ====================
    def get_count_by_user_id(
            self,
            end_user_id: uuid.UUID,
    ):
        """Return the number of perceptual memories owned by an end user."""
        db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}")

        try:
            owned = self.db.query(MemoryPerceptualModel).filter(
                MemoryPerceptualModel.end_user_id == end_user_id
            )
            return owned.count()
        except Exception as e:
            db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
            raise

    def get_count_by_type(
            self,
            end_user_id: uuid.UUID,
            perceptual_type: PerceptualType,
    ):
        """Return the number of an end user's memories of a given type."""
        db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}, type={perceptual_type}")

        try:
            typed = self.db.query(MemoryPerceptualModel).filter(
                MemoryPerceptualModel.end_user_id == end_user_id,
                MemoryPerceptualModel.perceptual_type == perceptual_type
            )
            return typed.count()
        except Exception as e:
            db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
            raise

    def get_timeline(
            self,
            end_user_id: uuid.UUID,
            query: PerceptualQuerySchema
    ) -> Tuple[int, List[MemoryPerceptualModel]]:
        """Get one page of a user's perceptual memories, newest first.

        ``query.filter.type`` (when not None) restricts the perceptual type;
        ``query.page`` and ``query.page_size`` select the page.

        Returns:
            A ``(total_count, memories)`` tuple, where ``total_count`` is the
            number of rows matching the filters before pagination.
        """
        db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}")

        try:
            base_query = self.db.query(MemoryPerceptualModel).filter(
                MemoryPerceptualModel.end_user_id == end_user_id
            )

            wanted_type = query.filter.type
            if wanted_type is not None:
                base_query = base_query.filter(
                    MemoryPerceptualModel.perceptual_type == wanted_type
                )

            total_count = base_query.count()

            skip = (query.page - 1) * query.page_size
            ordered = base_query.order_by(desc(MemoryPerceptualModel.created_time))
            memories = ordered.offset(skip).limit(query.page_size).all()

            db_logger.info(
                f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}")
            return total_count, memories

        except Exception as e:
            db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
            raise

    def get_by_type(
            self,
            end_user_id: uuid.UUID,
            perceptual_type: PerceptualType,
            limit: int = 10,
            offset: int = 0
    ) -> List[MemoryPerceptualModel]:
        """Get an end user's memories of one perceptual type, newest first."""
        db_logger.debug(f"Querying perceptual memories by type: end_user_id={end_user_id}, type={perceptual_type}")

        try:
            criteria = and_(
                MemoryPerceptualModel.end_user_id == end_user_id,
                MemoryPerceptualModel.perceptual_type == perceptual_type
            )
            ordered = self.db.query(MemoryPerceptualModel).filter(criteria).order_by(
                desc(MemoryPerceptualModel.created_time)
            )
            memories = ordered.offset(offset).limit(limit).all()

            db_logger.debug(f"Query by type succeeded: count={len(memories)}")
            return memories

        except Exception as e:
            db_logger.error(f"Failed to query perceptual memories by type: end_user_id={end_user_id}, "
                            f"type={perceptual_type} - {str(e)}")
            raise
|
||||
503
api/app/repositories/memory_short_repository.py
Normal file
503
api/app/repositories/memory_short_repository.py
Normal file
@@ -0,0 +1,503 @@
|
||||
"""
|
||||
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
# 获取数据库专用日志器
|
||||
db_logger = get_db_logger()
|
||||
|
||||
|
||||
class ShortTermMemoryRepository:
|
||||
"""短期记忆仓储类"""
|
||||
|
||||
def __init__(self, db: Session):
    """Keep the SQLAlchemy session used by every repository method."""
    self.db = db
|
||||
|
||||
def create(
    self,
    end_user_id: str,
    messages: str,
    aimessages: str = None,
    search_switch: str = None,
    retrieved_content: List[Dict] = None,
) -> ShortTermMemory:
    """Create and commit a short-term memory record.

    Args:
        end_user_id: End user ID.
        messages: User message content.
        aimessages: AI reply content.
        search_switch: Search switch state.
        retrieved_content: Retrieved related content, shaped like
            ``[{}, {}]``; a missing or empty value is stored as ``[]``.

    Returns:
        ShortTermMemory: The created record, refreshed after commit.

    Raises:
        Exception: Any error is logged, the session is rolled back and the
            error is re-raised.
    """
    try:
        record = ShortTermMemory(
            end_user_id=end_user_id,
            messages=messages,
            aimessages=aimessages,
            search_switch=search_switch,
            retrieved_content=retrieved_content or [],
        )

        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)

        db_logger.info(f"成功创建短期记忆记录: {record.id} for user {end_user_id}")
        return record

    except Exception as e:
        self.db.rollback()
        db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
        raise
|
||||
|
||||
def count_by_user_id(self,end_user_id: str) -> int:
|
||||
"""根据ID获取短期记忆记录
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
Optional[ShortTermMemory]: 记忆对象,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
count = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(ShortTermMemory.end_user_id == end_user_id)
|
||||
.count()
|
||||
)
|
||||
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
|
||||
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
|
||||
"""获取用户最新的短期记忆记录
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
limit: 返回记录数限制,默认5条
|
||||
|
||||
Returns:
|
||||
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
|
||||
"""
|
||||
try:
|
||||
# 使用复合索引 ix_memory_short_term_user_time 优化查询
|
||||
memories = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(ShortTermMemory.end_user_id == end_user_id)
|
||||
.order_by(ShortTermMemory.created_at.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
|
||||
"""获取用户最近指定小时内的短期记忆记录
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
hours: 时间范围(小时),默认24小时
|
||||
|
||||
Returns:
|
||||
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
|
||||
"""
|
||||
try:
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
|
||||
|
||||
# 使用复合索引 ix_memory_short_term_user_time 优化查询
|
||||
memories = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(
|
||||
ShortTermMemory.end_user_id == end_user_id,
|
||||
ShortTermMemory.created_at >= cutoff_time
|
||||
)
|
||||
.order_by(ShortTermMemory.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
|
||||
"""删除指定ID的短期记忆记录
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
bool: 删除成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
deleted_count = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(ShortTermMemory.id == memory_id)
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if deleted_count > 0:
|
||||
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def delete_old_memories(self, days: int = 7) -> int:
|
||||
"""删除指定天数之前的短期记忆记录
|
||||
|
||||
Args:
|
||||
days: 保留天数,默认7天
|
||||
|
||||
Returns:
|
||||
int: 删除的记录数
|
||||
"""
|
||||
try:
|
||||
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
|
||||
|
||||
deleted_count = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(ShortTermMemory.created_at < cutoff_time)
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
|
||||
"""创建或更新短期记忆记录
|
||||
|
||||
根据 end_user_id、messages 和 aimessages 查找现有记录:
|
||||
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
|
||||
- 如果没有找到匹配的记录,则创建新记录
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
messages: 用户消息内容
|
||||
aimessages: AI回复消息内容
|
||||
search_switch: 搜索开关状态
|
||||
retrieved_content: 检索到的相关内容,格式为[{}, {}]
|
||||
|
||||
Returns:
|
||||
ShortTermMemory: 创建或更新的短期记忆对象
|
||||
"""
|
||||
try:
|
||||
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
|
||||
query_filters = [
|
||||
ShortTermMemory.end_user_id == end_user_id,
|
||||
ShortTermMemory.messages == messages
|
||||
]
|
||||
|
||||
# 如果 aimessages 不为空,则加入查询条件
|
||||
if aimessages is not None:
|
||||
query_filters.append(ShortTermMemory.aimessages == aimessages)
|
||||
else:
|
||||
# 如果 aimessages 为 None,则查找 aimessages 为 NULL 的记录
|
||||
query_filters.append(ShortTermMemory.aimessages.is_(None))
|
||||
|
||||
# 查找现有记录
|
||||
existing_memory = (
|
||||
self.db.query(ShortTermMemory)
|
||||
.filter(*query_filters)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_memory:
|
||||
# 更新现有记录
|
||||
existing_memory.messages = messages
|
||||
existing_memory.aimessages = aimessages
|
||||
existing_memory.search_switch = search_switch
|
||||
existing_memory.retrieved_content = retrieved_content or []
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(existing_memory)
|
||||
|
||||
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
|
||||
return existing_memory
|
||||
else:
|
||||
# 创建新记录
|
||||
new_memory = ShortTermMemory(
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
aimessages=aimessages,
|
||||
search_switch=search_switch,
|
||||
retrieved_content=retrieved_content or []
|
||||
)
|
||||
|
||||
self.db.add(new_memory)
|
||||
self.db.commit()
|
||||
self.db.refresh(new_memory)
|
||||
|
||||
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
|
||||
return new_memory
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
class LongTermMemoryRepository:
    """Data-access layer for ``LongTermMemory`` rows.

    Every method runs against the SQLAlchemy session passed to the
    constructor. If an operation raises, the session is rolled back, the
    failure is logged through ``db_logger`` and the exception is re-raised.
    Methods that write (create / delete / upsert) commit on their own.
    """

    def __init__(self, db: Session):
        self.db = db

    def create(self, end_user_id: str, retrieved_content: Optional[List[Dict]] = None) -> LongTermMemory:
        """Insert a long-term memory row, commit, and return it refreshed.

        Args:
            end_user_id: End-user ID the record belongs to.
            retrieved_content: Retrieved related content as a list of dicts;
                ``None`` (or any falsy value) is stored as ``[]``.

        Returns:
            The newly created ``LongTermMemory`` object.
        """
        try:
            memory = LongTermMemory(
                end_user_id=end_user_id,
                retrieved_content=retrieved_content or [],
            )

            self.db.add(memory)
            self.db.commit()
            self.db.refresh(memory)

            db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
            return memory

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
            raise

    def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
        """Fetch one long-term memory row by primary key.

        Args:
            memory_id: Primary key of the row.

        Returns:
            The matching object, or ``None`` when no row has that ID.
        """
        try:
            memory = (
                self.db.query(LongTermMemory)
                .filter(LongTermMemory.id == memory_id)
                .first()
            )

            if memory:
                db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
            else:
                db_logger.debug(f"未找到长期记忆记录 {memory_id}")

            return memory

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
            raise

    def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
        """Return one page of an end user's long-term memory rows.

        Args:
            end_user_id: End-user ID.
            limit: Maximum number of rows to return (default 100).
            offset: Number of rows to skip (default 0).

        Returns:
            Rows ordered by ``created_at`` descending.
        """
        try:
            memories = (
                self.db.query(LongTermMemory)
                .filter(LongTermMemory.end_user_id == end_user_id)
                .order_by(LongTermMemory.created_at.desc())
                .limit(limit)
                .offset(offset)
                .all()
            )

            db_logger.info(f"成功查询用户 {end_user_id} 的 {len(memories)} 条长期记忆记录")
            return memories

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
            raise

    def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
        """Find an end user's rows whose ``retrieved_content`` contains a keyword.

        The JSON document is cast to text and matched with a substring
        (``LIKE '%keyword%'``) test.

        Args:
            end_user_id: End-user ID.
            keyword: Substring to look for in the serialized content.
            limit: Maximum number of rows to return (default 50).

        Returns:
            Matching rows ordered by ``created_at`` descending.
        """
        # Imported locally so this block does not depend on the file's
        # top-level import list.
        from sqlalchemy import String, cast

        try:
            # ``.astext`` only exists on an *indexed* JSON expression
            # (e.g. ``column["key"].astext``); calling it on the bare column
            # raises AttributeError while the query is being built. Casting
            # the whole document to text gives the intended substring search.
            # NOTE(review): a LIKE over the cast text will presumably not use
            # the GIN index the original comment mentioned — confirm if this
            # query becomes hot.
            content_text = cast(LongTermMemory.retrieved_content, String)

            memories = (
                self.db.query(LongTermMemory)
                .filter(
                    LongTermMemory.end_user_id == end_user_id,
                    content_text.contains(keyword),
                )
                .order_by(LongTermMemory.created_at.desc())
                .limit(limit)
                .all()
            )

            db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}' 的 {len(memories)} 条长期记忆记录")
            return memories

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
            raise

    def delete_by_id(self, memory_id: uuid.UUID) -> bool:
        """Delete the row with the given ID and commit.

        Args:
            memory_id: Primary key of the row to delete.

        Returns:
            ``True`` if at least one row was deleted, ``False`` otherwise.
        """
        try:
            deleted_count = (
                self.db.query(LongTermMemory)
                .filter(LongTermMemory.id == memory_id)
                .delete(synchronize_session=False)
            )

            self.db.commit()

            if deleted_count > 0:
                db_logger.info(f"成功删除长期记忆记录 {memory_id}")
                return True

            db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
            return False

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
            raise

    def count_by_user_id(self, end_user_id: str) -> int:
        """Count the long-term memory rows that belong to one end user.

        Args:
            end_user_id: End-user ID.

        Returns:
            Number of matching rows.
        """
        try:
            count = (
                self.db.query(LongTermMemory)
                .filter(LongTermMemory.end_user_id == end_user_id)
                .count()
            )

            db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
            return count

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
            raise

    def upsert(self, end_user_id: str, retrieved_content: Optional[List[Dict]] = None) -> Optional[LongTermMemory]:
        """Insert a long-term memory row unless an identical one already exists.

        Despite the name, nothing is ever updated: the method either inserts
        a new row or does nothing. Duplicates are detected by comparing
        ``retrieved_content`` in Python against the user's 100 most recent
        rows only, so an identical but older row does not prevent an insert.

        Args:
            end_user_id: End-user ID.
            retrieved_content: Retrieved related content as a list of dicts;
                ``None`` (or any falsy value) is treated as ``[]``.

        Returns:
            The newly created object, or ``None`` when an identical row was
            found and the write was skipped.
        """
        try:
            retrieved_content = retrieved_content or []

            # Filter by user in SQL and compare the JSON payload in Python,
            # which avoids a JSON equality comparison in the database.
            existing_memories = (
                self.db.query(LongTermMemory)
                .filter(LongTermMemory.end_user_id == end_user_id)
                .order_by(LongTermMemory.created_at.desc())
                .limit(100)  # cap the rows loaded into memory
                .all()
            )

            for memory in existing_memories:
                if memory.retrieved_content == retrieved_content:
                    db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
                    return None

            new_memory = LongTermMemory(
                end_user_id=end_user_id,
                retrieved_content=retrieved_content,
            )

            self.db.add(new_memory)
            self.db.commit()
            self.db.refresh(new_memory)

            db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
            return new_memory

        except Exception as e:
            self.db.rollback()
            db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
            raise
|
||||
|
||||
|
||||
@@ -211,6 +211,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
|
||||
"dialog_id": s.dialog_id,
|
||||
"chunk_ids": s.chunk_ids,
|
||||
"content": s.content,
|
||||
"memory_type": s.memory_type, # 添加 memory_type 字段
|
||||
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
|
||||
"config_id": s.config_id, # 添加 config_id
|
||||
})
|
||||
|
||||
@@ -92,6 +92,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
WHEN entity.description IS NOT NULL AND entity.description <> ''
|
||||
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
|
||||
THEN entity.description ELSE e.description END,
|
||||
e.example = CASE
|
||||
WHEN entity.example IS NOT NULL AND entity.example <> ''
|
||||
THEN entity.example
|
||||
ELSE coalesce(e.example, '')
|
||||
END,
|
||||
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
|
||||
e.aliases = CASE
|
||||
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
|
||||
@@ -121,7 +126,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END,
|
||||
e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END,
|
||||
e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END,
|
||||
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END
|
||||
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END,
|
||||
e.is_explicit_memory = CASE WHEN entity.is_explicit_memory IS NOT NULL THEN entity.is_explicit_memory ELSE coalesce(e.is_explicit_memory, false) END
|
||||
RETURN e.id AS uuid
|
||||
"""
|
||||
|
||||
@@ -722,7 +728,12 @@ SET m += {
|
||||
chunk_ids: summary.chunk_ids,
|
||||
content: summary.content,
|
||||
summary_embedding: summary.summary_embedding,
|
||||
config_id: summary.config_id
|
||||
config_id: summary.config_id,
|
||||
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,
|
||||
activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END,
|
||||
access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END,
|
||||
last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END,
|
||||
access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END
|
||||
}
|
||||
RETURN m.id AS uuid
|
||||
"""
|
||||
@@ -857,3 +868,174 @@ neo4j_query_all = """
|
||||
"""
|
||||
|
||||
|
||||
'''针对当前节点下扩长的句子,实体和总结'''
|
||||
Memory_Timeline_ExtractedEntity="""
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) = $id
|
||||
AND (ms:ExtractedEntity OR ms:MemorySummary)
|
||||
|
||||
RETURN
|
||||
collect(
|
||||
DISTINCT
|
||||
CASE
|
||||
WHEN ms:ExtractedEntity THEN {
|
||||
text: ms.name,
|
||||
created_at: ms.created_at,
|
||||
type: "情景记忆"
|
||||
}
|
||||
END
|
||||
) AS ExtractedEntity,
|
||||
|
||||
collect(
|
||||
DISTINCT
|
||||
CASE
|
||||
WHEN ms:MemorySummary THEN {
|
||||
text: ms.content,
|
||||
created_at: ms.created_at,
|
||||
type: "长期沉淀"
|
||||
}
|
||||
END
|
||||
) AS MemorySummary,
|
||||
|
||||
collect(
|
||||
DISTINCT {
|
||||
text: e.statement,
|
||||
created_at: e.created_at,
|
||||
type: "情绪记忆"
|
||||
}
|
||||
) AS statement;
|
||||
|
||||
|
||||
"""
|
||||
Memory_Timeline_MemorySummary="""
|
||||
MATCH (n)-[r1]-(e)-[r2]-(ms)
|
||||
WHERE elementId(n) =$id
|
||||
AND (ms:MemorySummary OR ms:ExtractedEntity)
|
||||
RETURN
|
||||
collect(
|
||||
DISTINCT
|
||||
CASE
|
||||
WHEN ms:ExtractedEntity THEN {
|
||||
text: ms.name,
|
||||
created_at: ms.created_at
|
||||
}
|
||||
END
|
||||
) AS ExtractedEntity,
|
||||
|
||||
collect(
|
||||
DISTINCT
|
||||
CASE
|
||||
WHEN n:MemorySummary THEN {
|
||||
text: n.content,
|
||||
created_at: n.created_at
|
||||
}
|
||||
END
|
||||
) AS MemorySummary,
|
||||
|
||||
collect(
|
||||
DISTINCT {
|
||||
text: e.statement,
|
||||
created_at: e.created_at
|
||||
}
|
||||
) AS statement;
|
||||
"""
|
||||
Memory_Timeline_Statement="""
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
|
||||
CALL {
|
||||
WITH n
|
||||
MATCH (n)-[]-(m:ExtractedEntity)
|
||||
WHERE NOT m:MemorySummary AND NOT m:Chunk
|
||||
RETURN collect(
|
||||
DISTINCT {
|
||||
text: m.name,
|
||||
created_at: m.created_at,
|
||||
type: "情景记忆"
|
||||
}
|
||||
) AS ExtractedEntity
|
||||
}
|
||||
|
||||
CALL {
|
||||
WITH n
|
||||
MATCH (n)-[]-(m:MemorySummary)
|
||||
WHERE NOT m:Chunk
|
||||
RETURN collect(
|
||||
DISTINCT {
|
||||
text: m.content,
|
||||
created_at: m.created_at,
|
||||
type: "长期沉淀"
|
||||
}
|
||||
) AS MemorySummary
|
||||
}
|
||||
|
||||
RETURN
|
||||
ExtractedEntity,
|
||||
MemorySummary,
|
||||
{
|
||||
text: n.statement,
|
||||
created_at: n.created_at,
|
||||
type: "情绪记忆"
|
||||
} AS statement;
|
||||
|
||||
|
||||
"""
|
||||
|
||||
'''针对当前节点,主要获取更加完整的句子节点'''
|
||||
Memory_Space_Emotion_Statement="""
|
||||
MATCH (n)
|
||||
WHERE elementId(n) = $id
|
||||
RETURN
|
||||
n.emotion_intensity AS emotion_intensity,
|
||||
n.created_at AS created_at,
|
||||
n.emotion_type AS emotion_type,
|
||||
n.statement AS statement;
|
||||
|
||||
"""
|
||||
Memory_Space_Emotion_MemorySummary="""
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
MATCH (e)-[]-(ms)
|
||||
WHERE ms:MemorySummary OR ms:ExtractedEntity
|
||||
}
|
||||
RETURN DISTINCT
|
||||
e.emotion_intensity AS emotion_intensity,
|
||||
e.created_at AS created_at,
|
||||
e.emotion_type AS emotion_type,
|
||||
e.statement AS statement;
|
||||
"""
|
||||
Memory_Space_Emotion_ExtractedEntity="""
|
||||
MATCH (n)-[]-(e)
|
||||
WHERE elementId(n) = $id
|
||||
AND EXISTS {
|
||||
MATCH (e)-[]-(ms:ExtractedEntity)
|
||||
}
|
||||
RETURN DISTINCT
|
||||
e.emotion_intensity AS emotion_intensity,
|
||||
e.created_at AS created_at,
|
||||
e.emotion_type AS emotion_type,
|
||||
e.statement AS statement;
|
||||
"""
|
||||
|
||||
'''获取实体'''
|
||||
|
||||
Memory_Space_User="""
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE n.group_id = $group_id AND m.name="用户"
|
||||
return DISTINCT elementId(m) as id
|
||||
"""
|
||||
Memory_Space_Entity="""
|
||||
MATCH (n)-[]-(m)
|
||||
WHERE elementId(m) = $id AND m.entity_type = "Person"
|
||||
RETURN
|
||||
DISTINCT m.name as name,m.group_id as group_id
|
||||
"""
|
||||
Memory_Space_Associative="""
|
||||
MATCH (u)-[]-(x)-[]-(h)
|
||||
WHERE elementId(u) = $user_id
|
||||
AND elementId(h) = $id
|
||||
RETURN DISTINCT
|
||||
x.statement as statement,x.created_at as created_at
|
||||
"""
|
||||
|
||||
|
||||
@@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['access_history'] = n.get('access_history') or []
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
|
||||
273
api/app/repositories/neo4j/memory_summary_repository.py
Normal file
273
api/app/repositories/neo4j/memory_summary_repository.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Summary Repository Module
|
||||
|
||||
This module provides data access functionality for MemorySummary nodes.
|
||||
|
||||
Classes:
|
||||
MemorySummaryRepository: Repository for managing MemorySummary CRUD operations
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class MemorySummaryRepository(BaseNeo4jRepository):
    """Read-side repository for ``MemorySummary`` nodes.

    Provides queries by group_id, user_id, time range and content keywords.
    All query methods return plain dictionaries produced by ``_map_to_dict``.

    Attributes:
        connector: Neo4j connector instance (set by the base class).
        node_label: Node label, fixed as "MemorySummary".
    """

    def __init__(self, connector: Neo4jConnector):
        """Initialize the repository.

        Args:
            connector: Neo4j connector instance.
        """
        super().__init__(connector, "MemorySummary")

    def _map_to_dict(self, node_data: Dict) -> Dict[str, Any]:
        """Convert one query-result record into a summary dictionary.

        Args:
            node_data: A result record; the node is read from its ``'n'`` key
                when present, otherwise the record itself is used.

        Returns:
            A new dictionary with the node's properties; an ISO-format string
            in ``created_at`` is parsed into a ``datetime``.
        """
        # Copy before editing so the caller's record is never mutated.
        summary = dict(node_data.get('n', node_data))

        created_at = summary.get('created_at')
        if isinstance(created_at, str):
            summary['created_at'] = datetime.fromisoformat(created_at)

        return summary

    async def _find_with_date_range(
        self,
        base_filter: str,
        params: Dict[str, Any],
        start_date: Optional[datetime],
        end_date: Optional[datetime],
    ) -> List[Dict[str, Any]]:
        """Run a filtered MATCH with optional created_at bounds, newest first.

        Args:
            base_filter: Cypher boolean expression over ``n`` for the WHERE
                clause; it must reference query parameters only.
            params: Query parameters; must already contain ``limit``.
            start_date: Inclusive lower bound on ``n.created_at``, if given.
            end_date: Inclusive upper bound on ``n.created_at``, if given.

        Returns:
            Mapped summary dictionaries.
        """
        query = f"""
        MATCH (n:{self.node_label})
        WHERE {base_filter}
        """

        # NOTE(review): _map_to_dict accepts created_at stored as a string;
        # if that is how it is persisted, these comparisons against datetime
        # parameters may not behave as intended — confirm the stored type.
        if start_date:
            query += " AND n.created_at >= $start_date"
            params["start_date"] = start_date

        if end_date:
            query += " AND n.created_at <= $end_date"
            params["end_date"] = end_date

        query += """
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_dict(r) for r in results]

    async def find_by_group_id(
        self,
        group_id: str,
        limit: int = 1000,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """Query memory summaries by group_id.

        Args:
            group_id: Group ID to filter by.
            limit: Maximum number of results to return.
            start_date: Optional inclusive start of the created_at range.
            end_date: Optional inclusive end of the created_at range.

        Returns:
            Memory summary dictionaries, newest first.
        """
        return await self._find_with_date_range(
            "n.group_id = $group_id",
            {"group_id": group_id, "limit": limit},
            start_date,
            end_date,
        )

    async def find_by_user_id(
        self,
        user_id: str,
        limit: int = 1000,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """Query memory summaries by user_id.

        Args:
            user_id: User ID to filter by.
            limit: Maximum number of results to return.
            start_date: Optional inclusive start of the created_at range.
            end_date: Optional inclusive end of the created_at range.

        Returns:
            Memory summary dictionaries, newest first.
        """
        return await self._find_with_date_range(
            "n.user_id = $user_id",
            {"user_id": user_id, "limit": limit},
            start_date,
            end_date,
        )

    async def find_by_group_and_user(
        self,
        group_id: str,
        user_id: str,
        limit: int = 1000,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """Query memory summaries by both group_id and user_id.

        Args:
            group_id: Group ID to filter by.
            user_id: User ID to filter by.
            limit: Maximum number of results to return.
            start_date: Optional inclusive start of the created_at range.
            end_date: Optional inclusive end of the created_at range.

        Returns:
            Memory summary dictionaries, newest first.
        """
        return await self._find_with_date_range(
            "n.group_id = $group_id AND n.user_id = $user_id",
            {"group_id": group_id, "user_id": user_id, "limit": limit},
            start_date,
            end_date,
        )

    async def find_recent_summaries(
        self,
        group_id: str,
        days: int = 7,
        limit: int = 1000
    ) -> List[Dict[str, Any]]:
        """Query a group's memory summaries created in the last ``days`` days.

        The cutoff is computed inside the database as
        ``datetime() - duration({days: $days})``.

        Args:
            group_id: Group ID to filter by.
            days: Number of recent days to query.
            limit: Maximum number of results to return.

        Returns:
            Memory summary dictionaries, newest first.
        """
        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
        AND n.created_at >= datetime() - duration({{days: $days}})
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(
            query,
            group_id=group_id,
            days=days,
            limit=limit
        )
        return [self._map_to_dict(r) for r in results]

    async def find_by_content_keywords(
        self,
        group_id: str,
        keywords: List[str],
        limit: int = 100
    ) -> List[Dict[str, Any]]:
        """Query a group's memory summaries whose content contains any keyword.

        Matching is a case-insensitive substring test; a summary matches when
        at least one keyword is found.

        Args:
            group_id: Group ID to filter by.
            keywords: Keywords to search for in ``content``.
            limit: Maximum number of results to return.

        Returns:
            Memory summary dictionaries, newest first; an empty list when
            ``keywords`` is empty.
        """
        # With no keywords the OR-clause would be "AND ()", which is invalid
        # Cypher; "any of nothing" matches nothing, so return early.
        if not keywords:
            return []

        params = {"group_id": group_id, "limit": limit}
        keyword_conditions = []

        # Each keyword is bound as its own parameter, never interpolated.
        for i, keyword in enumerate(keywords):
            keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
            params[f"keyword_{i}"] = keyword

        keyword_filter = " OR ".join(keyword_conditions)

        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
        AND ({keyword_filter})
        RETURN n
        ORDER BY n.created_at DESC
        LIMIT $limit
        """

        results = await self.connector.execute_query(query, **params)
        return [self._map_to_dict(r) for r in results]

    async def get_summary_count_by_group(self, group_id: str) -> int:
        """Get the number of memory summaries for a group.

        Args:
            group_id: Group ID to count summaries for.

        Returns:
            Number of memory summaries; 0 when the query returns no rows.
        """
        query = f"""
        MATCH (n:{self.node_label})
        WHERE n.group_id = $group_id
        RETURN count(n) as count
        """

        results = await self.connector.execute_query(query, group_id=group_id)
        return results[0]['count'] if results else 0
|
||||
|
||||
@@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['access_history'] = n.get('access_history') or []
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
|
||||
@@ -38,6 +38,33 @@ class ToolRepository:
|
||||
|
||||
return result[0] if result else None
|
||||
|
||||
@staticmethod
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
    """
    Look up the tenant_id of a workspace.

    Args:
        db: Database session
        workspace_id: Workspace ID

    Returns:
        The tenant_id as a UUID, or None when the workspace has no
        tenant_id / does not exist, or when the stored value cannot be
        parsed as a UUID.
    """
    from app.models.workspace_model import Workspace

    lookup = db.query(Workspace.tenant_id).filter(Workspace.id == workspace_id)
    raw_value = lookup.scalar()

    # Nothing found, or already the right type: hand it back unchanged.
    if raw_value is None or isinstance(raw_value, uuid.UUID):
        return raw_value

    # Tolerate a column whose stored type does not match (e.g. a string).
    try:
        return uuid.UUID(raw_value)
    except (ValueError, TypeError):
        return None
|
||||
|
||||
@staticmethod
|
||||
def find_by_tenant(
|
||||
db: Session,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import Optional, Any, List, Dict
|
||||
from typing import Optional, Any, List, Dict, Union
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
    """Tool configuration."""
    # Whether the tool is enabled; defaults to False.
    enabled: bool = Field(default=False, description="是否启用该工具")
    # Optional tool ID; defaults to None.
    tool_id: Optional[str] = Field(default=None, description="工具ID")
    # NOTE(review): the description text reads "tool-specific configuration",
    # identical to ToolOldConfig.config — confirm it describes `operation`.
    operation: Optional[str] = Field(default=None, description="工具特定配置")
|
||||
|
||||
class ToolOldConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
|
||||
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
|
||||
)
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = Field(
|
||||
default_factory=dict,
|
||||
description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)"
|
||||
tools: List[ToolConfig] = Field(
|
||||
default_factory=list,
|
||||
description="Agent 可用的工具列表"
|
||||
)
|
||||
|
||||
|
||||
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
|
||||
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
||||
|
||||
# 工具配置
|
||||
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
||||
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
|
||||
|
||||
|
||||
# ---------- Output Schemas ----------
|
||||
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
|
||||
variables: List[VariableDefinition] = []
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = {}
|
||||
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
|
||||
@@ -35,14 +35,14 @@ class ChatRequest(BaseModel):
|
||||
class Message(BaseModel):
|
||||
"""消息输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
conversation_id: uuid.UUID
|
||||
role: str
|
||||
content: str
|
||||
meta_data: Optional[Dict[str, Any]] = None
|
||||
created_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -51,7 +51,7 @@ class Message(BaseModel):
|
||||
class Conversation(BaseModel):
|
||||
"""会话输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
workspace_id: uuid.UUID
|
||||
@@ -63,11 +63,11 @@ class Conversation(BaseModel):
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -84,3 +84,12 @@ class ChatResponse(BaseModel):
|
||||
message: str
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: Optional[float] = None
|
||||
|
||||
|
||||
# ---------- Conversation Summary Schemas ----------
|
||||
class ConversationOut(BaseModel):
    """Conversation summary output: theme, questions, summary, takeaways and a score."""

    theme: str
    question: list[str]
    summary: str
    takeaways: list[str]
    # Integer score; the scale is not defined in this schema.
    info_score: int
|
||||
|
||||
264
api/app/schemas/implicit_memory_schema.py
Normal file
264
api/app/schemas/implicit_memory_schema.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Implicit Memory Schemas
|
||||
|
||||
This module defines the Pydantic schemas for the implicit memory system API.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
|
||||
|
||||
class FrequencyPattern(str, Enum):
    """Closed set of frequency patterns a habit can have.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
    SEASONAL = "seasonal"
    OCCASIONAL = "occasional"
    EVENT_TRIGGERED = "event_triggered"
|
||||
|
||||
|
||||
# Request Schemas
|
||||
|
||||
class TimeRange(BaseModel):
    """Time range for analysis.

    ``end_date`` must be strictly later than ``start_date``. In JSON output
    both bounds are rendered as integer epoch milliseconds.
    """

    start_date: datetime.datetime
    end_date: datetime.datetime

    @field_validator('end_date')
    @classmethod
    def end_date_must_be_after_start_date(cls, v, info):
        """Reject an ``end_date`` that is not strictly after ``start_date``.

        Raises:
            ValueError: if ``end_date`` is not after ``start_date``, or if the
                two values cannot be compared (one naive, one timezone-aware).
        """
        # start_date is missing from info.data when it failed its own
        # validation; there is nothing to compare against in that case.
        start = info.data.get('start_date')
        if start is None:
            return v
        try:
            not_after = v <= start
        except TypeError as exc:
            # Naive vs. aware datetimes cannot be ordered. Surface this as a
            # validation error instead of letting the TypeError escape.
            raise ValueError(
                'start_date and end_date must both be timezone-aware or both be naive'
            ) from exc
        if not_after:
            raise ValueError('end_date must be after start_date')
        return v

    @field_serializer("start_date", when_used="json")
    def _serialize_start_date(self, dt: datetime.datetime):
        """Render ``start_date`` as epoch milliseconds for JSON output."""
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("end_date", when_used="json")
    def _serialize_end_date(self, dt: datetime.datetime):
        """Render ``end_date`` as epoch milliseconds for JSON output."""
        return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class DateRange(BaseModel):
    """Date range for filtering.

    Both bounds are optional. When both are given, ``end_date`` must be
    strictly later than ``start_date``. In JSON output each bound is rendered
    as integer epoch milliseconds, or ``None`` when unset.
    """

    start_date: Optional[datetime.datetime] = None
    end_date: Optional[datetime.datetime] = None

    @field_validator('end_date')
    @classmethod
    def end_date_must_be_after_start_date(cls, v, info):
        """Reject an ``end_date`` that is not strictly after ``start_date``.

        Raises:
            ValueError: if both bounds are set and ``end_date`` is not after
                ``start_date``, or if the two values cannot be compared
                (one naive, one timezone-aware).
        """
        start = info.data.get('start_date')
        # Only compare when both bounds are actually present.
        if not v or not start:
            return v
        try:
            not_after = v <= start
        except TypeError as exc:
            # Naive vs. aware datetimes cannot be ordered. Surface this as a
            # validation error instead of letting the TypeError escape.
            raise ValueError(
                'start_date and end_date must both be timezone-aware or both be naive'
            ) from exc
        if not_after:
            raise ValueError('end_date must be after start_date')
        return v

    @field_serializer("start_date", when_used="json")
    def _serialize_start_date(self, dt: Optional[datetime.datetime]):
        """Render ``start_date`` as epoch milliseconds, or ``None`` if unset."""
        return int(dt.timestamp() * 1000) if dt else None

    @field_serializer("end_date", when_used="json")
    def _serialize_end_date(self, dt: Optional[datetime.datetime]):
        """Render ``end_date`` as epoch milliseconds, or ``None`` if unset."""
        return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
class AnalysisConfig(BaseModel):
    """Configuration for analysis operations."""

    # Optional LLM model identifier; None when not specified.
    llm_model_id: Optional[str] = None
    # Defaults to 100.
    batch_size: int = 100
    # Defaults to 0.5.
    confidence_threshold: float = 0.5
    # Historical trends are excluded unless explicitly requested.
    include_historical_trends: bool = False
    # Optional cap on conversations; None when not specified.
    max_conversations: Optional[int] = None
|
||||
|
||||
|
||||
# Response Schemas
|
||||
|
||||
class PreferenceTagResponse(BaseModel):
    """A user preference tag with detailed context.

    ``created_at`` and ``updated_at`` are rendered as integer epoch
    milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    tag_name: str
    # Constrained to the closed interval [0.0, 1.0].
    confidence_score: float = Field(ge=0.0, le=1.0)
    supporting_evidence: List[str]
    context_details: str
    created_at: datetime.datetime
    updated_at: datetime.datetime
    conversation_references: List[str]
    category: Optional[str] = None

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, value: datetime.datetime):
        """Render ``created_at`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)

    @field_serializer("updated_at", when_used="json")
    def _serialize_updated_at(self, value: datetime.datetime):
        """Render ``updated_at`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class DimensionScoreResponse(BaseModel):
    """Score for a single personality dimension."""

    model_config = ConfigDict(from_attributes=True)

    dimension_name: str
    # Constrained to [0.0, 100.0].
    percentage: float = Field(
        ge=0.0,
        le=100.0,
    )
    evidence: List[str]
    reasoning: str
    # Integer constrained to [0, 100].
    confidence_level: int = Field(
        ge=0,
        le=100,
    )
|
||||
|
||||
|
||||
class DimensionPortraitResponse(BaseModel):
    """Four-dimension personality portrait.

    Holds one score each for creativity, aesthetic, technology and
    literature. ``analysis_timestamp`` is rendered as integer epoch
    milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    user_id: str

    # The four scored dimensions.
    creativity: DimensionScoreResponse
    aesthetic: DimensionScoreResponse
    technology: DimensionScoreResponse
    literature: DimensionScoreResponse

    analysis_timestamp: datetime.datetime
    total_summaries_analyzed: int
    # Optional; None when no trend data is attached.
    historical_trends: Optional[List[Dict[str, Any]]] = None

    @field_serializer("analysis_timestamp", when_used="json")
    def _serialize_analysis_timestamp(self, value: datetime.datetime):
        """Render ``analysis_timestamp`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class InterestCategoryResponse(BaseModel):
    """Interest category with its percentage share and supporting evidence."""

    model_config = ConfigDict(from_attributes=True)

    category_name: str
    # Constrained to [0.0, 100.0].
    percentage: float = Field(
        ge=0.0,
        le=100.0,
    )
    evidence: List[str]
    # Optional free-form string; allowed values are not defined here.
    trending_direction: Optional[str] = None
|
||||
|
||||
|
||||
class InterestAreaDistributionResponse(BaseModel):
    """Distribution of user interests across four areas.

    The areas are tech, lifestyle, music and art. ``analysis_timestamp`` is
    rendered as integer epoch milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    user_id: str

    # The four interest areas.
    tech: InterestCategoryResponse
    lifestyle: InterestCategoryResponse
    music: InterestCategoryResponse
    art: InterestCategoryResponse

    analysis_timestamp: datetime.datetime
    total_summaries_analyzed: int

    @property
    def total_percentage(self) -> float:
        """Return the sum of the four areas' percentages."""
        # Explicit left-to-right addition, in field order.
        running = self.tech.percentage + self.lifestyle.percentage
        running = running + self.music.percentage
        running = running + self.art.percentage
        return running

    @field_serializer("analysis_timestamp", when_used="json")
    def _serialize_analysis_timestamp(self, value: datetime.datetime):
        """Render ``analysis_timestamp`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class BehaviorHabitResponse(BaseModel):
    """A behavioral habit identified from conversations.

    ``first_observed`` and ``last_observed`` are rendered as integer epoch
    milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    habit_description: str
    frequency_pattern: FrequencyPattern
    time_context: str
    # Integer constrained to [0, 100].
    confidence_level: int = Field(
        ge=0,
        le=100,
    )
    first_observed: datetime.datetime
    last_observed: datetime.datetime
    # Defaults to True.
    is_current: bool = True
    specific_examples: List[str]

    @field_serializer("first_observed", when_used="json")
    def _serialize_first_observed(self, value: datetime.datetime):
        """Render ``first_observed`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)

    @field_serializer("last_observed", when_used="json")
    def _serialize_last_observed(self, value: datetime.datetime):
        """Render ``last_observed`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class UserProfileResponse(BaseModel):
    """Comprehensive user profile.

    Combines preference tags, the dimension portrait, the interest-area
    distribution and behavior habits. ``created_at`` and ``updated_at`` are
    rendered as integer epoch milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    user_id: str

    # Profile sections.
    preference_tags: List[PreferenceTagResponse]
    dimension_portrait: DimensionPortraitResponse
    interest_area_distribution: InterestAreaDistributionResponse
    behavior_habits: List[BehaviorHabitResponse]

    # Bookkeeping.
    profile_version: int
    created_at: datetime.datetime
    updated_at: datetime.datetime
    total_summaries_analyzed: int
    # Constrained to [0.0, 1.0].
    analysis_completeness_score: float = Field(
        ge=0.0,
        le=1.0,
    )

    @field_serializer("created_at", when_used="json")
    def _serialize_created_at(self, value: datetime.datetime):
        """Render ``created_at`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)

    @field_serializer("updated_at", when_used="json")
    def _serialize_updated_at(self, value: datetime.datetime):
        """Render ``updated_at`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
# Internal/Business Logic Schemas
|
||||
|
||||
class MemorySummary(BaseModel):
    """Memory summary from the existing storage system.

    ``timestamp`` is rendered as integer epoch milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    summary_id: str
    content: str
    timestamp: datetime.datetime
    participants: List[str]
    summary_type: str

    @field_serializer("timestamp", when_used="json")
    def _serialize_timestamp(self, value: datetime.datetime):
        """Render ``timestamp`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class UserMemorySummary(BaseModel):
    """Memory summary filtered for a specific user's content.

    ``timestamp`` is rendered as integer epoch milliseconds in JSON output.
    """

    model_config = ConfigDict(from_attributes=True)

    summary_id: str
    user_id: str
    user_content: str
    timestamp: datetime.datetime
    # Constrained to [0.0, 1.0].
    confidence_score: float = Field(
        ge=0.0,
        le=1.0,
    )
    summary_type: str

    @field_serializer("timestamp", when_used="json")
    def _serialize_timestamp(self, value: datetime.datetime):
        """Render ``timestamp`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
class SummaryAnalysisResult(BaseModel):
    """Result of analyzing memory summaries.

    ``analysis_timestamp`` is rendered as integer epoch milliseconds in JSON
    output.
    """

    model_config = ConfigDict(from_attributes=True)

    user_id: str
    preferences: List[PreferenceTagResponse]

    # Evidence strings grouped under string keys.
    dimension_evidence: Dict[str, List[str]]
    interest_evidence: Dict[str, List[str]]
    # Free-form dicts; their schema is not defined here.
    habit_evidence: List[Dict[str, Any]]

    analysis_timestamp: datetime.datetime
    summaries_analyzed: List[str]

    @field_serializer("analysis_timestamp", when_used="json")
    def _serialize_analysis_timestamp(self, value: datetime.datetime):
        """Render ``analysis_timestamp`` as epoch milliseconds for JSON output."""
        if not value:
            return None
        return int(value.timestamp() * 1000)
|
||||
|
||||
|
||||
# Backward-compatibility aliases: each un-suffixed name is bound to the same
# class object as its *Response counterpart, so existing code keeps working.
PreferenceTag = PreferenceTagResponse
DimensionScore = DimensionScoreResponse
DimensionPortrait = DimensionPortraitResponse
InterestCategory = InterestCategoryResponse
InterestAreaDistribution = InterestAreaDistributionResponse
BehaviorHabit = BehaviorHabitResponse
UserProfile = UserProfileResponse
|
||||
136
api/app/schemas/memory_perceptual_schema.py
Normal file
136
api/app/schemas/memory_perceptual_schema.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||
|
||||
|
||||
class PerceptualFilter(BaseModel):
    """Filter criteria for a perceptual memory query."""

    # Optional perceptual type; None is the default.
    type: Optional[PerceptualType] = Field(
        None,
        description="Perceptual type used for filtering the query; optional",
    )
|
||||
|
||||
|
||||
class PerceptualQuerySchema(BaseModel):
    """Paginated perceptual memory query: a filter plus page and page size."""

    # A fresh, empty PerceptualFilter is built when none is supplied.
    filter: PerceptualFilter = Field(
        default_factory=PerceptualFilter,
        description="Query filter containing perceptual type criteria",
    )

    # 1-based page index; must be at least 1.
    page: int = Field(default=1, ge=1, description="Page number for pagination, starting from 1")

    # Must lie in 1..100; defaults to 10.
    page_size: int = Field(default=10, ge=1, le=100, description="Number of records per page, range 1-100")
|
||||
|
||||
|
||||
class PerceptualMemoryItem(BaseModel):
    """A single perceptual memory item.

    Can be populated from attribute-bearing objects (``from_attributes``).
    """

    # Declared via model_config rather than the class-based ``Config``,
    # which pydantic v2 deprecates; this also matches the other schemas.
    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID = Field(..., description="Unique memory ID")
    perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
    file_path: str = Field(..., description="File path in the storage service")
    file_ext: str = Field(..., description="File extension")
    file_name: str = Field(..., description="File name")
    # Optional; None when absent.
    summary: Optional[str] = Field(None, description="summary")
    storage_type: FileStorageType = Field(..., description="Storage type for file")
    # Integer creation time. NOTE(review): the unit (seconds vs. milliseconds)
    # is not stated here — confirm against the producer.
    created_time: int = Field(..., description="create time")
|
||||
|
||||
|
||||
class PerceptualTimelineResponse(BaseModel):
    """Perceptual memory timeline response: one page of items plus paging totals.

    Can be populated from attribute-bearing objects (``from_attributes``).
    """

    # Declared via model_config rather than the class-based ``Config``,
    # which pydantic v2 deprecates; this also matches the other schemas.
    model_config = ConfigDict(from_attributes=True)

    # Total number of records.
    total: int = Field(..., description="总数量")
    # Current page number.
    page: int = Field(..., description="当前页码")
    # Records per page.
    page_size: int = Field(..., description="每页大小")
    # Total number of pages.
    total_pages: int = Field(..., description="总页数")
    # Memory items on this page.
    memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表")
|
||||
|
||||
|
||||
# --------------------------
|
||||
# TODO: FileMetaData
|
||||
# --------------------------
|
||||
class Identity(BaseModel):
    """Identity section of file metadata: title, filename, source and author."""

    title: str
    filename: str
    # One of: upload | crawl | system (not enforced by the type).
    source: str
    # Optional; None when unknown.
    author: Optional[str] = None
|
||||
|
||||
|
||||
class Semantic(BaseModel):
    """Semantic section of file metadata.

    The listed value sets are documented conventions; the fields are plain
    strings and are not validated against them.
    """

    topic: str
    domain: str
    # beginner | intermediate | advanced
    difficulty: str
    # informative | instructional | promotional
    intent: str
    # positive | neutral | negative
    sentiment: str
|
||||
|
||||
|
||||
class Content(BaseModel):
    """Content section of file metadata: summary, keywords, topic and domain."""

    summary: str
    keywords: list[str]
    topic: str
    domain: str
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """Usage section of file metadata: target audiences and use cases."""

    target_audience: list[str]
    use_cases: list[str]
|
||||
|
||||
|
||||
class Stats(BaseModel):
    """Statistics section of file metadata."""

    # Duration in seconds; optional, None when not provided.
    duration_sec: Optional[int] = None
    char_count: int
    word_count: int
|
||||
|
||||
|
||||
class Processing(BaseModel):
    """Processing section of file metadata: flags for applied pipeline steps."""

    transcribed: bool
    ocr_applied: bool
    chunked: bool
    vectorized: bool
    # Optional embedding model name; None when not provided.
    embedding_model: Optional[str] = None
|
||||
|
||||
|
||||
class VideoModal(BaseModel):
    """Video-specific metadata: a list of scene strings."""

    scene: list[str]
|
||||
|
||||
|
||||
class AudioModal(BaseModel):
    """Audio-specific metadata: the speaker count."""

    speaker_count: int
|
||||
|
||||
|
||||
class TextModal(BaseModel):
    """Text-specific metadata: section count, title and first line."""

    section_count: int
    title: str
    first_line: str
|
||||
|
||||
|
||||
class Asset(BaseModel):
    """Top-level file metadata record.

    Carries basic format fields, the metadata sections, and exactly one
    modality-specific payload.
    """

    # Basic descriptors.
    type: str
    # text | audio | video
    modality: str
    # docx | mp3 | mp4
    format: str
    language: str
    encoding: str

    # Metadata sections.
    identity: Identity
    semantic: Semantic
    content: Content
    usage: Usage
    stats: Stats
    processing: Processing
    # Kept as a string; no date format is enforced here.
    created_at: str
    # One of the modality-specific models.
    modalities: AudioModal | TextModal | VideoModal
|
||||
@@ -409,10 +409,9 @@ class ForgettingTriggerRequest(BaseModel):
|
||||
"""手动触发遗忘周期请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
group_id: Optional[str] = Field(None, description="组ID(可选,用于过滤特定组的节点)")
|
||||
group_id: str = Field(..., description="组ID(即终端用户ID,必填)")
|
||||
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
|
||||
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(可选,用于指定遗忘引擎配置)") # TODO 后续group_id更换成enduser_id,自动与config_id关联 ,要删除此行
|
||||
|
||||
|
||||
class ForgettingConfigResponse(BaseModel):
|
||||
@@ -450,15 +449,36 @@ class ForgettingConfigUpdateRequest(BaseModel):
|
||||
forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="遗忘周期间隔(小时)")
|
||||
|
||||
|
||||
class ForgettingCycleHistoryPoint(BaseModel):
    """Forgetting-cycle history data point (used for trend charts).

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Date label, e.g. '1/1', '1/2'.
    date: str = Field(
        ...,
        description="日期(格式: '1/1', '1/2')",
    )
    # Number of nodes merged that day.
    merged_count: int = Field(
        ...,
        description="每日融合节点数",
    )
    # Average activation value; optional.
    average_activation: Optional[float] = Field(
        default=None,
        description="平均激活值",
    )
    # Total node count.
    total_nodes: int = Field(
        ...,
        description="总节点数",
    )
    # Execution time as a Unix timestamp in seconds.
    execution_time: int = Field(
        ...,
        description="执行时间(Unix时间戳,秒)",
    )
|
||||
|
||||
|
||||
class PendingForgettingNode(BaseModel):
    """A node pending forgetting.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Node identifier.
    node_id: str = Field(
        ...,
        description="节点ID",
    )
    # Node type: statement / entity / summary.
    node_type: str = Field(
        ...,
        description="节点类型:statement/entity/summary",
    )
    # Short summary of the node content.
    content_summary: str = Field(
        ...,
        description="内容摘要",
    )
    # Activation value.
    activation_value: float = Field(
        ...,
        description="激活值",
    )
    # Last access time as a Unix timestamp in seconds.
    last_access_time: int = Field(
        ...,
        description="最后访问时间(Unix时间戳,秒)",
    )
|
||||
|
||||
|
||||
class ForgettingStatsResponse(BaseModel):
|
||||
"""遗忘引擎统计信息响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
consistency_check: Optional[Dict[str, Any]] = Field(None, description="数据一致性检查结果")
|
||||
nodes_merged_total: int = Field(..., description="累计融合节点对数")
|
||||
recent_cycles: List[Dict[str, Any]] = Field(..., description="最近的遗忘周期记录")
|
||||
timestamp: str = Field(..., description="统计时间(ISO格式)")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
|
||||
class ForgettingReportResponse(BaseModel):
|
||||
|
||||
@@ -66,9 +66,9 @@ class MultiAgentConfigCreate(BaseModel):
|
||||
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
orchestration_mode: str = Field(
|
||||
...,
|
||||
pattern="^(sequential|parallel|conditional|loop)$",
|
||||
description="编排模式:sequential|parallel|conditional|loop"
|
||||
default="collaboration",
|
||||
pattern="^(collaboration|supervisor)$",
|
||||
description="协作模式:collaboration(协作)| supervisor(监督)"
|
||||
)
|
||||
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
|
||||
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
|
||||
@@ -84,14 +84,15 @@ class MultiAgentConfigUpdate(BaseModel):
|
||||
"""更新多 Agent 配置"""
|
||||
master_agent_id: Optional[uuid.UUID] = None
|
||||
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
|
||||
default_model_config_id : uuid.UUID = Field(description="默认模型配置ID")
|
||||
model_parameters: ModelParameters | None = Field(
|
||||
default_factory=ModelParameters,
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
|
||||
model_parameters: Optional[ModelParameters] = Field(
|
||||
None,
|
||||
description="模型参数配置(temperature、max_tokens 等)"
|
||||
)
|
||||
orchestration_mode: Optional[str] = Field(
|
||||
None,
|
||||
pattern="^(sequential|parallel|conditional|loop)$"
|
||||
default="collaboration",
|
||||
pattern="^(collaboration|supervisor)$",
|
||||
description="协作模式:collaboration(协作)| supervisor(监督)"
|
||||
)
|
||||
sub_agents: Optional[List[SubAgentConfig]] = None
|
||||
routing_rules: Optional[List[RoutingRule]] = None
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user