Merge branch 'develop' into feature/knowledgeBase_yjp

This commit is contained in:
yujiangping
2026-01-12 17:53:22 +08:00
231 changed files with 18309 additions and 4060 deletions

25
api/app/base/type.py Normal file
View File

@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from sqlalchemy import TypeDecorator, JSON
class PydanticType(TypeDecorator):
impl = JSON
def __init__(self, pydantic_model: type[BaseModel]):
super().__init__()
self.model = pydantic_model
def process_bind_param(self, value, dialect):
# 入库Model -> dict
if value is None:
return None
if isinstance(value, self.model):
return value.dict()
return value # 已经是 dict 也放行
def process_result_value(self, value, dialect):
# 出库dict -> Model
if value is None:
return None
# return self.model.parse_obj(value) # pydantic v1
return self.model.model_validate(value) # pydantic v2

View File

@@ -4,39 +4,45 @@
认证方式: JWT Token
"""
from fastapi import APIRouter
from . import (
model_controller,
task_controller,
test_controller,
user_controller,
auth_controller,
workspace_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
api_key_controller,
app_controller,
upload_controller,
auth_controller,
chunk_controller,
document_controller,
emotion_config_controller,
emotion_controller,
file_controller,
home_page_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
memory_forget_controller,
memory_reflection_controller,
api_key_controller,
release_share_controller,
public_share_controller,
memory_short_term_controller,
memory_storage_controller,
model_controller,
multi_agent_controller,
workflow_controller,
emotion_controller,
emotion_config_controller,
prompt_optimizer_controller,
public_share_controller,
release_share_controller,
setup_controller,
task_controller,
test_controller,
tool_controller,
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
workspace_controller,
memory_forget_controller,
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
)
from . import user_memory_controllers
# 创建管理端 API 路由器
manager_router = APIRouter()
@@ -71,8 +77,12 @@ manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router)
manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
__all__ = ["manager_router"]

View File

@@ -48,6 +48,7 @@ def list_apps(
include_shared: bool = True,
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
@@ -55,8 +56,19 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
# 正常分页查询
items_orm, total = app_service.list_apps(
db,
workspace_id=workspace_id,
@@ -69,8 +81,6 @@ def list_apps(
pagesize=pagesize,
)
# 使用 AppService 的转换方法来设置 is_shared 字段
service = app_service.AppService(db)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@@ -506,7 +516,7 @@ async def draft_run(
multi_agent_request = MultiAgentRunRequest(
message=payload.message,
conversation_id=payload.conversation_id,
user_id=payload.user_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables or {},
use_llm_routing=True # 默认启用 LLM 路由
)
@@ -728,9 +738,23 @@ async def draft_run_compare(
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(model_item.model_config_id))
# 获取 agent_cfg.model_parameters如果是 ModelParameters 对象则转为字典
agent_model_params = agent_cfg.model_parameters
if hasattr(agent_model_params, 'model_dump'):
agent_model_params = agent_model_params.model_dump()
elif not isinstance(agent_model_params, dict):
agent_model_params = {}
# 获取 model_item.model_parameters如果是 ModelParameters 对象则转为字典
item_model_params = model_item.model_parameters
if hasattr(item_model_params, 'model_dump'):
item_model_params = item_model_params.model_dump()
elif not isinstance(item_model_params, dict):
item_model_params = {}
merged_parameters = {
**(agent_cfg.model_parameters or {}),
**(model_item.model_parameters or {})
**(agent_model_params or {}),
**(item_model_params or {})
}
model_configs.append({

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
@@ -26,4 +27,12 @@ def get_workspace_list(
):
"""获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功")
return success(data=workspace_list, msg="工作空间列表获取成功")
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""获取系统版本号+说明"""
return success(data={
"version": settings.SYSTEM_VERSION,
"introduction": settings.SYSTEM_INTRODUCTION
}, msg="系统版本获取成功")

View File

@@ -0,0 +1,312 @@
from datetime import datetime
from typing import Optional
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import (
cur_workspace_access_guard,
get_current_user,
)
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.implicit_memory_service import ImplicitMemoryService
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/implicit-memory",
tags=["Implicit Memory"],
)
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
"""
Centralized error handling for implicit memory operations.
Args:
e: The exception that occurred
operation: Description of the operation that failed
user_id: Optional user ID for logging context
Returns:
Standardized error response
"""
error_context = f"user_id={user_id}" if user_id else "unknown user"
if isinstance(e, ValueError):
if "user" in str(e).lower() and "not found" in str(e).lower():
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e))
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e))
else:
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e))
elif isinstance(e, KeyError):
api_logger.warning(f"Missing required data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e))
elif isinstance(e, (ConnectionError, TimeoutError)):
api_logger.error(f"Service unavailable for {operation}: {error_context}")
return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e))
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e))
elif "storage" in str(e).lower() or "database" in str(e).lower():
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e))
else:
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e))
def validate_user_id(user_id: str) -> None:
"""
Validate user ID format and constraints.
Args:
user_id: User ID to validate
Raises:
ValueError: If user ID is invalid
"""
if not user_id or not user_id.strip():
raise ValueError("User ID cannot be empty")
if len(user_id.strip()) < 1:
raise ValueError("User ID is too short")
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
"""
Validate date range parameters.
Args:
start_date: Start date
end_date: End date
Raises:
ValueError: If date range is invalid
"""
if (start_date and not end_date) or (end_date and not start_date):
raise ValueError("Both start_date and end_date must be provided together")
if start_date and end_date and start_date >= end_date:
raise ValueError("start_date must be before end_date")
if start_date and start_date > datetime.now():
raise ValueError("start_date cannot be in the future")
def validate_confidence_threshold(threshold: float) -> None:
"""
Validate confidence threshold parameter.
Args:
threshold: Confidence threshold to validate
Raises:
ValueError: If threshold is invalid
"""
if not 0.0 <= threshold <= 1.0:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
end_date: Optional[datetime] = Query(None, description="Filter end date"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user preference tags with filtering options.
Args:
user_id: Target user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List of preference tags matching the filters
"""
api_logger.info(f"Preference tags requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
validate_confidence_threshold(confidence_threshold)
validate_date_range(start_date, end_date)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Build date range
date_range = None
if start_date and end_date:
from app.schemas.implicit_memory_schema import DateRange
date_range = DateRange(start_date=start_date, end_date=end_date)
# Get preference tags
tags = await service.get_preference_tags(
user_id=user_id,
confidence_threshold=confidence_threshold,
tag_category=tag_category,
date_range=date_range
)
api_logger.info(f"Retrieved {len(tags)} preference tags for user: {user_id}")
return success(data=[tag.model_dump(mode='json') for tag in tags], msg="偏好标签获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's four-dimension personality portrait.
Args:
user_id: Target user ID
include_history: Whether to include historical trend data
Returns:
Four-dimension personality portrait with scores and evidence
"""
api_logger.info(f"Dimension portrait requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
portrait = await service.get_dimension_portrait(
user_id=user_id,
include_history=include_history
)
api_logger.info(f"Dimension portrait retrieved for user: {user_id}")
return success(data=portrait.model_dump(mode='json'), msg="四维画像获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's interest area distribution across four areas.
Args:
user_id: Target user ID
include_trends: Whether to include trend analysis data
Returns:
Interest area distribution with percentages and evidence
"""
api_logger.info(f"Interest area distribution requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
distribution = await service.get_interest_area_distribution(
user_id=user_id,
include_trends=include_trends
)
api_logger.info(f"Interest area distribution retrieved for user: {user_id}")
return success(data=distribution.model_dump(mode='json'), msg="兴趣领域分布获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's behavioral habits with filtering options.
Args:
user_id: Target user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
Returns:
List of behavioral habits matching the filters
"""
api_logger.info(f"Behavior habits requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Convert string confidence level to numerical
numerical_confidence = None
if confidence_level:
confidence_mapping = {
"high": 85,
"medium": 50,
"low": 20
}
numerical_confidence = confidence_mapping.get(confidence_level.lower())
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
habits = await service.get_behavior_habits(
user_id=user_id,
confidence_level=numerical_confidence,
frequency_pattern=frequency_pattern,
time_period=time_period
)
api_logger.info(f"Retrieved {len(habits)} behavior habits for user: {user_id}")
return success(data=[habit.model_dump(mode='json') for habit in habits], msg="行为习惯获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id)

View File

@@ -76,9 +76,28 @@ async def trigger_forgetting_cycle(
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 通过 group_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(payload.group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {payload.group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {payload.group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={payload.group_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, "
f"group_id={payload.group_id}, config_id={config_id}, max_batch={payload.max_merge_batch_size}, "
f"min_days={payload.min_days_since_access}"
)
@@ -89,7 +108,7 @@ async def trigger_forgetting_cycle(
group_id=payload.group_id,
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=payload.config_id
config_id=config_id
)
# 构建响应
@@ -217,7 +236,6 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
group_id: Optional[str] = None,
config_id: Optional[int] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -227,8 +245,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。
Args:
group_id: 组ID可选
config_id: 配置ID可选用于获取遗忘阈值
group_id: 组ID即 end_user_id可选)
current_user: 当前用户
db: 数据库会话
@@ -242,6 +259,27 @@ async def get_forgetting_stats(
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 group_id通过它获取 config_id
config_id = None
if group_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"group_id={group_id}, config_id={config_id}"

View File

@@ -0,0 +1,255 @@
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.memory_perceptual_model import PerceptualType
from app.schemas.memory_perceptual_schema import (
PerceptualQuerySchema,
PerceptualFilter
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_perceptual_service import MemoryPerceptualService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/perceptual",
tags=["Perceptual Memory System"],
dependencies=[Depends(get_current_user)]
)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
group_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(group_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
return success(
data=count_stats,
msg="Memory statistics retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(group_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: group_id={group_id}")
return success(
data=None,
msg="No visual memory available"
)
api_logger.info(f"Latest visual memory retrieved successfully: file={visual_memory.get('file_name')}")
return success(
data=visual_memory,
msg="Latest visual memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(group_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: group_id={group_id}")
return success(
data=None,
msg="No audio memory available"
)
api_logger.info(f"Latest audio memory retrieved successfully: file={audio_memory.get('file_name')}")
return success(
data=audio_memory,
msg="Latest audio memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(group_id)
if text_memory is None:
api_logger.info(f"No text memory found: group_id={group_id}")
return success(
data=None,
msg="No text memory available"
)
api_logger.info(f"Latest text memory retrieved successfully: file={text_memory.get('file_name')}")
return success(
data=text_memory,
msg="Latest text memory retrieved successfully"
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
group_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve a timeline of perceptual memories for a user group.
Args:
group_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Timeline data of perceptual memories
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"group_id={group_id}, type={perceptual_type}, page={page}"
)
try:
query = PerceptualQuerySchema(
filter=PerceptualFilter(type=perceptual_type),
page=page,
page_size=page_size
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(group_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
f"returned={len(timeline_data.memories)}"
)
return success(
data=timeline_data.model_dump(),
msg="Perceptual memory timeline retrieved successfully"
)
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"error={str(e)}"
)
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch perceptual memory timeline",
)

View File

@@ -0,0 +1,43 @@
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
)
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)
result = {
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -0,0 +1,134 @@
import uuid
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.schemas.response_schema import ApiResponse
from app.services.conversation_service import ConversationService
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/work",
tags=["Working Memory System"],
dependencies=[Depends(get_current_user)]
)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{group_id}/conversations", response_model=ApiResponse)
def get_conversations(
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve all conversations for the current user in a specific group.
Args:
group_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains a list of conversation IDs.
Notes:
- Initializes the ConversationService with the current DB session.
- Returns only conversation IDs for lightweight response.
- Logs can be added to trace requests in production.
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
group_id
)
return success(data=[
{
"id": conversation.id,
"title": conversation.title
} for conversation in conversations
], msg="get conversations success")
@router.get("/{group_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve the message history for a specific conversation.
Args:
conversation_id (UUID): The ID of the conversation to fetch messages from.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains the list of messages in the conversation.
Notes:
- Uses ConversationService to fetch messages.
- Consider paginating results if message history is large.
- Logging can be added for audit and debugging.
"""
conversation_service = ConversationService(db)
messages_obj = conversation_service.get_messages(
conversation_id,
)
messages = [
{
"role": message.role,
"content": message.content,
"created_at": int(message.created_at.timestamp() * 1000),
}
for message in messages_obj
]
return success(data=messages, msg="get conversation history success")
@router.get("/{group_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve detailed information about a specific conversation.
This endpoint will fetch the conversation detail for the user. If the detail
does not exist or is outdated, it will trigger the LLM to generate a new summary.
Args:
conversation_id (UUID): The ID of the conversation.
current_user (User, optional): The authenticated user making the request.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains the conversation detail serialized as a dictionary.
Notes:
- Uses async ConversationService to fetch or generate the conversation detail.
- Handles workspace and user-specific context automatically.
- Logging and exception handling should be implemented for production monitoring.
"""
conversation_service = ConversationService(db)
detail = await conversation_service.get_conversation_detail(
user=current_user,
conversation_id=conversation_id,
workspace_id=current_user.current_workspace_id
)
return success(data=detail.model_dump(), msg="get conversation detail success")

View File

@@ -108,16 +108,23 @@ async def get_prompt_opt(
service = PromptOptimizerService(db)
async def event_generator():
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:'message'\ndata: {json.dumps(chunk)}\n\n"
yield "event:start\ndata: {}\n\n"
try:
async for chunk in service.optimize_prompt(
tenant_id=current_user.tenant_id,
model_id=data.model_id,
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
except Exception as e:
yield f"event:error\ndata: {json.dumps(
{"error": str(e)}
)}\n\n"
yield "event:end\ndata: {}\n\n"
return StreamingResponse(
event_generator(),

View File

@@ -1,4 +1,5 @@
import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
@@ -18,7 +19,7 @@ from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -288,7 +289,7 @@ async def chat(
password = None # Token 认证不需要密码
# end_user_id = user_id
other_id = user_id
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
@@ -364,6 +365,9 @@ async def chat(
config = release.config or {}
if not config.get("sub_agents"):
raise BusinessException("多 Agent 应用未配置子 Agent", BizCode.AGENT_CONFIG_MISSING)
elif app_type == AppType.WORKFLOW:
# Multi-Agent 类型:验证多 Agent 配置
pass
else:
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@@ -392,6 +396,8 @@ async def chat(
if app_type == AppType.AGENT:
# 流式返回
agent_config = agent_config_4_app_release(release)
if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
@@ -424,10 +430,11 @@ async def chat(
user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=payload.agent_config,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -463,10 +470,12 @@ async def chat(
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
# config = workflow_config_4_app_release(release)
config = multi_agent_config_4_app_release(release)
if payload.stream:
async def event_generator():
@@ -479,8 +488,8 @@ async def chat(
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -551,8 +560,71 @@ async def chat(
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -1,4 +1,5 @@
"""App 服务接口 - 基于 API Key 认证"""
import json
from typing import Annotated
from fastapi import APIRouter, Depends, Request, Body
@@ -21,7 +22,7 @@ from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import dict_to_multi_agent_config, dict_to_workflow_config, agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
@@ -153,7 +154,8 @@ async def chat(
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -177,7 +179,8 @@ async def chat(
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -194,8 +197,8 @@ async def chat(
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -211,7 +214,6 @@ async def chat(
# 多 Agent 非流式返回
result = await app_chat_service.multi_agent_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
@@ -226,22 +228,29 @@ async def chat(
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.WORKFLOW:
# 多 Agent 流式返回
config = dict_to_workflow_config(app.current_release.config,app.id)
config = workflow_config_4_app_release(app.current_release)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
):
yield event
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
return StreamingResponse(
event_generator(),
@@ -253,23 +262,34 @@ async def chat(
}
)
# 非流式返回
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
extra={
"result_type": str(type(result)),
"has_response": "response" in result if isinstance(result, dict) else False
}
)
return success(
data=result,
msg="工作流任务执行成功"
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -1,23 +1,22 @@
from fastapi import APIRouter, Depends, status, Query, HTTPException
from langchain_core.messages import HumanMessage, SystemMessage
from fastapi import APIRouter, Depends, status, HTTPException, Body, Path
from fastapi.responses import StreamingResponse
from langchain_core.prompts import ChatPromptTemplate
from sqlalchemy.orm import Session
from typing import List, Optional
import uuid
from app.core.models import RedBearLLM, RedBearRerank
from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelApiKey, ModelProvider, ModelType
from app.models.user_model import User
from app.schemas import model_schema
from app.models.models_model import ModelApiKey
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.schemas.response_schema import ApiResponse
from app.schemas.app_schema import AppChatRequest
from app.services.model_service import ModelConfigService
from app.services.handoffs_service import get_handoffs_service_for_app, reset_handoffs_service_cache
from app.services.conversation_service import ConversationService
from app.core.logging_config import get_api_logger
from app.dependencies import get_current_user
# 获取API专用日志器
api_logger = get_api_logger()
@@ -28,6 +27,8 @@ router = APIRouter(
)
# ==================== 原有测试接口 ====================
@router.get("/llm/{model_id}", response_model=ApiResponse)
def test_llm(
model_id: uuid.UUID,
@@ -50,7 +51,6 @@ def test_llm(
template = """Question: {question}
Answer: Let's think step by step."""
# ChatPromptTemplate
prompt = ChatPromptTemplate.from_template(template)
chain = prompt | llm
answer = chain.invoke({"question": "What is LangChain?"})
@@ -80,13 +80,13 @@ def test_embedding(
base_url=apiConfig.api_base
))
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
embeddings = model.embed_documents(data)
print(embeddings)
query = "我想找一个适合学习的地方。"
@@ -114,13 +114,123 @@ def test_rerank(
base_url=apiConfig.api_base
))
query = "最近哪家咖啡店评价最好?"
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
data = [
"最近哪家咖啡店评价最好?",
"附近有没有推荐的咖啡厅?",
"明天天气预报说会下雨。",
"北京是中国的首都。",
"我想找一个适合学习的地方。"
]
scores = model.rerank(query=query, documents=data, top_n=3)
print(scores)
return success(msg="测试Rerank成功", data={"query": query, "documents": data, "scores": scores})
# ==================== Handoffs 测试接口 ====================
@router.post("/handoffs/{app_id}")
async def test_handoffs(
app_id: uuid.UUID = Path(..., description="应用 ID"),
request: AppChatRequest = Body(...),
current_user=Depends(get_current_user),
db: Session = Depends(get_db)
):
"""测试 Agent Handoffs 功能
演示 LangGraph 实现的多 Agent 协作和动态切换
- 从数据库 multi_agent_config 获取 Agent 配置
- 根据用户问题自动切换到合适的 Agent
- 使用 conversation_id 保持会话状态
- 通过 stream 参数控制是否流式输出
事件类型(流式):
- start: 开始执行
- agent: 当前 Agent 信息
- message: 流式消息内容
- handoff: Agent 切换事件
- end: 执行结束
- error: 错误信息
"""
try:
workspace_id = current_user.current_workspace_id
# 获取或创建会话
conversation_service = ConversationService(db)
if request.conversation_id:
# 验证会话存在
conversation = conversation_service.get_conversation(uuid.UUID(request.conversation_id))
if not conversation:
raise HTTPException(status_code=404, detail="会话不存在")
conversation_id = str(conversation.id)
else:
# 创建新会话
conversation = conversation_service.create_or_get_conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=request.user_id,
is_draft=True
)
conversation_id = str(conversation.id)
# 根据 stream 参数决定返回方式
if request.stream:
# 流式返回
service = get_handoffs_service_for_app(app_id, db, streaming=True)
return StreamingResponse(
service.chat_stream(
message=request.message,
conversation_id=conversation_id
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
else:
# 非流式返回
service = get_handoffs_service_for_app(app_id, db, streaming=False)
result = await service.chat(
message=request.message,
conversation_id=conversation_id
)
return success(data=result, msg="Handoffs 测试成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Handoffs 测试失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/handoffs/{app_id}/agents", response_model=ApiResponse)
def get_handoff_agents(
app_id: uuid.UUID = Path(..., description="应用 ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user)
):
"""获取应用的 Handoff Agent 列表"""
try:
service = get_handoffs_service_for_app(app_id, db, streaming=False)
agents = service.get_agents()
return success(data={"agents": agents}, msg="获取 Agent 列表成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
api_logger.error(f"获取 Agent 列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/handoffs/{app_id}/reset")
def reset_handoff_service(
app_id: uuid.UUID = Path(..., description="应用 ID"),
current_user=Depends(get_current_user)
):
"""重置指定应用的 Handoff 服务缓存"""
reset_handoffs_service_cache(app_id)
return success(msg="Handoff 服务已重置")

View File

@@ -215,8 +215,8 @@ async def sync_mcp_tools(
"""同步MCP工具列表"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if result["success"] is False:
raise HTTPException(status_code=404, detail=result["message"])
if not result.get("success", False):
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
return success(data=result, msg="MCP工具列表同步完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -11,14 +11,22 @@ from app.db import get_db
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
from app.core.api_key_utils import timestamp_to_datetime
from app.services.user_memory_service import (
UserMemoryService,
analytics_node_statistics,
analytics_memory_types,
analytics_graph_data,
)
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.schemas.user_memory_schema import (
EpisodicMemoryOverviewRequest,
EpisodicMemoryDetailsRequest,
ExplicitMemoryOverviewRequest,
ExplicitMemoryDetailsRequest,
)
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -41,24 +49,27 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api(
end_user_id: str, # 使用 end_user_id
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""获取缓存的记忆洞察报告"""
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""
获取缓存的记忆洞察报告
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
如需生成新的洞察报告,请使用专门的生成接口。
"""
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
if result["is_cached"]:
# 缓存存在,返回缓存数据
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# 缓存不存在,返回提示消息
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
return success(data=result, msg="数据尚未生成")
except Exception as e:
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
@@ -66,24 +77,27 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str, # 使用 end_user_id
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""获取缓存的用户摘要"""
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
) -> dict:
"""
获取缓存的用户摘要
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
"""
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
if result["is_cached"]:
# 缓存存在,返回缓存数据
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
else:
# 缓存不存在,返回提示消息
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
return success(data=result, msg="查询成功")
return success(data=result, msg="数据尚未生成")
except Exception as e:
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
@@ -97,35 +111,35 @@ async def generate_cache_api(
) -> dict:
"""
手动触发缓存生成
- 如果提供 end_user_id只为该用户生成
- 如果不提供,为当前工作空间的所有用户生成
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
group_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={group_id if group_id else '全部用户'}"
)
try:
if group_id:
# 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
# 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
# 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
# 构建响应
result = {
"end_user_id": group_id,
@@ -133,7 +147,7 @@ async def generate_cache_api(
"summary_success": summary_result["success"],
"errors": []
}
# 收集错误信息
if not insight_result["success"]:
result["errors"].append({
@@ -145,29 +159,29 @@ async def generate_cache_api(
"type": "summary",
"error": summary_result.get("error")
})
# 记录结果
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {group_id} 生成缓存")
else:
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")
else:
# 为整个工作空间生成
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
# 记录统计信息
api_logger.info(
f"工作空间 {workspace_id} 批量生成完成: "
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
)
return success(data=result, msg="批量生成完成")
except Exception as e:
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
@@ -180,18 +194,18 @@ async def get_node_statistics_api(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try:
# 调用新的记忆类型统计函数
result = await analytics_memory_types(db, end_user_id)
# 计算总数用于日志
total_count = sum(item["count"] for item in result)
api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
@@ -211,31 +225,31 @@ async def get_graph_data_api(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 参数验证
if limit > 1000:
limit = 1000
api_logger.warning("limit 参数超过最大值,已调整为 1000")
if depth > 3:
depth = 3
api_logger.warning("depth 参数超过最大值,已调整为 3")
# 解析 node_types 参数
node_types_list = None
if node_types:
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
api_logger.info(
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
)
try:
result = await analytics_graph_data(
db=db,
@@ -245,19 +259,19 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
return success(data=result, msg=result.get("message", "查询成功"))
api_logger.info(
f"成功获取图数据: end_user_id={end_user_id}, "
f"nodes={result['statistics']['total_nodes']}, "
f"edges={result['statistics']['total_edges']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
@@ -270,25 +284,25 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -300,10 +314,10 @@ async def get_end_user_profile(
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
except Exception as e:
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
@@ -317,56 +331,56 @@ async def update_end_user_profile(
) -> dict:
"""
更新终端用户的基本信息
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。
"""
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 更新字段(只更新提供的字段,排除 end_user_id
# 允许 None 值来重置字段(如 hire_date
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = UserMemoryService.timestamp_to_datetime(hire_date_timestamp)
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -378,11 +392,243 @@ async def update_end_user_profile(
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")
# 获取情绪数据
emotion = MemoryEmotion(id, label)
emotion_result = await emotion.get_emotion()
# 获取交互数据
interaction = MemoryInteraction(id, label)
interaction_result = await interaction.get_interaction_frequency()
# 关闭连接
await emotion.close()
await interaction.close()
result = {
"emotion": emotion_result,
"interaction": interaction_result
}
api_logger.info(f"关系演变查询成功: id={id}, table={label}")
return success(data=result, msg="关系演变")
except Exception as e:
api_logger.error(f"关系演变查询失败: id={id}, table={label}, error={str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "关系演变查询失败", str(e))
@router.post("/classifications/episodic-memory", response_model=ApiResponse)
async def get_episodic_memory_overview_api(
request: EpisodicMemoryOverviewRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
获取情景记忆总览
返回指定用户的所有情景记忆列表,包括标题和创建时间。
支持通过时间范围、情景类型和标题关键词进行筛选。
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆总览但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 验证参数
valid_time_ranges = ["all", "today", "this_week", "this_month"]
valid_episodic_types = ["all", "conversation", "project_work", "learning", "decision", "important_event"]
if request.time_range not in valid_time_ranges:
return fail(BizCode.INVALID_PARAMETER, f"无效的时间范围参数,可选值:{', '.join(valid_time_ranges)}")
if request.episodic_type not in valid_episodic_types:
return fail(BizCode.INVALID_PARAMETER, f"无效的情景类型参数,可选值:{', '.join(valid_episodic_types)}")
# 处理 title_keyword去除首尾空格
title_keyword = request.title_keyword.strip() if request.title_keyword else None
api_logger.info(
f"情景记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}, time_range={request.time_range}, episodic_type={request.episodic_type}, "
f"title_keyword={title_keyword}"
)
try:
# 调用Service层方法
result = await user_memory_service.get_episodic_memory_overview(
db, request.end_user_id, request.time_range, request.episodic_type, title_keyword
)
api_logger.info(
f"成功获取情景记忆总览: end_user_id={request.end_user_id}, "
f"total={result['total']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"情景记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆总览查询失败", str(e))
@router.post("/classifications/episodic-memory-details", response_model=ApiResponse)
async def get_episodic_memory_details_api(
request: EpisodicMemoryDetailsRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
获取情景记忆详情
返回指定情景记忆的详细信息,包括涉及对象、情景类型、内容记录和情绪。
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询情景记忆详情但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"情景记忆详情查询请求: end_user_id={request.end_user_id}, summary_id={request.summary_id}, "
f"user={current_user.username}, workspace={workspace_id}"
)
try:
# 调用Service层方法
result = await user_memory_service.get_episodic_memory_details(
db=db,
end_user_id=request.end_user_id,
summary_id=request.summary_id
)
api_logger.info(
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
)
return success(data=result, msg="查询成功")
except ValueError as e:
# 处理情景记忆不存在的情况
api_logger.warning(f"情景记忆不存在: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
return fail(BizCode.INVALID_PARAMETER, "情景记忆不存在", str(e))
except Exception as e:
api_logger.error(f"情景记忆详情查询失败: end_user_id={request.end_user_id}, summary_id={request.summary_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "情景记忆详情查询失败", str(e))
@router.post("/classifications/explicit-memory", response_model=ApiResponse)
async def get_explicit_memory_overview_api(
request: ExplicitMemoryOverviewRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
获取显性记忆总览
返回指定用户的所有显性记忆列表,包括标题、完整内容、创建时间和情绪信息。
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆总览但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"显性记忆总览查询请求: end_user_id={request.end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
# 调用Service层方法
result = await user_memory_service.get_explicit_memory_overview(
db, request.end_user_id
)
api_logger.info(
f"成功获取显性记忆总览: end_user_id={request.end_user_id}, "
f"total={result['total']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"显性记忆总览查询失败: end_user_id={request.end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "显性记忆总览查询失败", str(e))
@router.post("/classifications/explicit-memory-details", response_model=ApiResponse)
async def get_explicit_memory_details_api(
request: ExplicitMemoryDetailsRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
获取显性记忆详情
根据 memory_id 返回情景记忆或语义记忆的详细信息。
- 情景记忆:包括标题、内容、情绪、创建时间
- 语义记忆:包括名称、核心定义、详细笔记、创建时间
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询显性记忆详情但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"显性记忆详情查询请求: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
f"user={current_user.username}, workspace={workspace_id}"
)
try:
# 调用Service层方法
result = await user_memory_service.get_explicit_memory_details(
db=db,
end_user_id=request.end_user_id,
memory_id=request.memory_id
)
api_logger.info(
f"成功获取显性记忆详情: end_user_id={request.end_user_id}, memory_id={request.memory_id}, "
f"memory_type={result.get('memory_type')}"
)
return success(data=result, msg="查询成功")
except ValueError as e:
# 处理记忆不存在的情况
api_logger.warning(f"显性记忆不存在: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
return fail(BizCode.INVALID_PARAMETER, "显性记忆不存在", str(e))
except Exception as e:
api_logger.error(f"显性记忆详情查询失败: end_user_id={request.end_user_id}, memory_id={request.memory_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "显性记忆详情查询失败", str(e))

View File

@@ -11,10 +11,16 @@ import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
@@ -159,11 +165,13 @@ class LangChainAgent:
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
return messagss_list
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
@@ -203,7 +211,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
@@ -221,11 +228,26 @@ class LangChainAgent:
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory=await self.term_memory_redis_read(end_user_id)
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory=';'.join(history_term_memory)
logger.info(f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try:
@@ -314,10 +336,6 @@ class LangChainAgent:
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.db import get_db
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
@@ -329,14 +347,24 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
history_term_memory = await self.term_memory_redis_read(end_user_id)
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try:

View File

@@ -164,6 +164,10 @@ class Settings:
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
SYSTEM_INTRODUCTION: str = os.getenv("SYSTEM_INTRODUCTION", "")
def get_memory_output_path(self, filename: str = "") -> str:
"""

View File

@@ -82,6 +82,13 @@ class BizCode(IntEnum):
MEMORY_WRITE_FAILED = 9501
MEMORY_READ_FAILED = 9502
MEMORY_CONFIG_NOT_FOUND = 9503
# Implicit Memory API96xx
INVALID_USER_ID = 9601
INSUFFICIENT_DATA = 9602
INVALID_FILTER_PARAMS = 9603
ANALYSIS_FAILED = 9604
PROFILE_STORAGE_ERROR = 9605
# 系统100xx
INTERNAL_ERROR = 10001
@@ -159,6 +166,13 @@ HTTP_MAPPING = {
BizCode.MEMORY_READ_FAILED: 500,
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
# Implicit Memory API 错误码映射
BizCode.INVALID_USER_ID: 400,
BizCode.INSUFFICIENT_DATA: 400,
BizCode.INVALID_FILTER_PARAMS: 400,
BizCode.ANALYSIS_FAILED: 500,
BizCode.PROFILE_STORAGE_ERROR: 500,
BizCode.INTERNAL_ERROR: 500,
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,

View File

@@ -5,19 +5,16 @@ This module provides analytics and insights for the memory system.
Available functions:
- get_hot_memory_tags: Get hot memory tags by frequency
- MemoryInsight: Generate memory insight reports
- get_recent_activity_stats: Get recent activity statistics
- generate_user_summary: Generate user summary
Note: MemoryInsight and generate_user_summary have been moved to
app.services.user_memory_service for better architecture.
"""
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
__all__ = [
"get_hot_memory_tags",
"MemoryInsight",
"get_recent_activity_stats",
"generate_user_summary",
]

View File

@@ -0,0 +1,6 @@
"""Implicit Memory Module
This module provides behavior analysis capabilities that build comprehensive user profiles
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
"""

View File

@@ -0,0 +1 @@
"""Analyzers package for implicit memory analysis components."""

View File

@@ -0,0 +1,271 @@
"""Dimension Analyzer for Implicit Memory System
This module implements LLM-based personality dimension analysis from user memory summaries.
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
providing percentage scores with evidence and reasoning.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
DimensionPortrait,
DimensionScore,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class DimensionData(BaseModel):
"""Individual dimension analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
reasoning: str = ""
confidence_level: int = 50 # Default to medium confidence
class DimensionAnalysisResponse(BaseModel):
"""Response model for dimension analysis."""
dimensions: Dict[str, DimensionData] = Field(default_factory=dict)
class DimensionAnalyzer:
"""Analyzes user memory summaries to extract personality dimensions."""
# Define the four dimensions we analyze
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the dimension analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_dimensions(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_portrait: Optional[DimensionPortrait] = None
) -> DimensionPortrait:
"""Analyze user summaries to extract personality dimensions.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_portrait: Optional existing portrait for incremental updates
Returns:
Dimension portrait with four personality dimensions
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_portrait(user_id)
try:
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_dimensions(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create dimension scores
dimension_scores = {}
current_time = datetime.now()
for dimension_name in self.DIMENSIONS:
# Handle response as dictionary
dimensions_data = response.get("dimensions", {})
dimension_data = dimensions_data.get(dimension_name)
if dimension_data:
# Validate and create dimension score
score = self._create_dimension_score(
dimension_name=dimension_name,
dimension_data=dimension_data
)
dimension_scores[dimension_name] = score
else:
# Create default score if missing
logger.warning(f"Missing dimension data for {dimension_name}, using default")
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
# Create dimension portrait
portrait = DimensionPortrait(
user_id=user_id,
creativity=dimension_scores["creativity"],
aesthetic=dimension_scores["aesthetic"],
technology=dimension_scores["technology"],
literature=dimension_scores["literature"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries),
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
)
logger.info(f"Created dimension portrait for user {user_id}")
return portrait
except LLMClientException:
raise
except Exception as e:
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Dimension analysis failed: {e}") from e
def _create_dimension_score(
self,
dimension_name: str,
dimension_data: dict
) -> DimensionScore:
"""Create a dimension score from analysis data.
Args:
dimension_name: Name of the dimension
dimension_data: Analysis data dictionary for the dimension
Returns:
DimensionScore object
"""
# Validate percentage - handle dict access
percentage = dimension_data.get("percentage", 0.0)
percentage = max(0.0, min(100.0, float(percentage)))
# Validate confidence level
confidence_level = self._validate_confidence_level(dimension_data.get("confidence_level", 50))
# Ensure evidence is not empty
evidence = dimension_data.get("evidence", [])
if not evidence:
evidence = ["No specific evidence found"]
# Ensure reasoning is not empty
reasoning = dimension_data.get("reasoning", "")
if not reasoning:
reasoning = f"Analysis for {dimension_name} dimension"
return DimensionScore(
dimension_name=dimension_name,
percentage=percentage,
evidence=evidence,
reasoning=reasoning,
confidence_level=confidence_level
)
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
"""Create a default dimension score when analysis fails.
Args:
dimension_name: Name of the dimension
Returns:
Default DimensionScore object
"""
return DimensionScore(
dimension_name=dimension_name,
percentage=0.0,
evidence=["Insufficient data for analysis"],
reasoning=f"No clear evidence found for {dimension_name} dimension",
confidence_level=20 # Low confidence as numerical value
)
def _validate_confidence_level(self, confidence_level) -> int:
"""Return confidence level as integer, handling both string and numeric inputs.
Args:
confidence_level: Confidence level (string or numeric)
Returns:
Confidence level as integer (0-100)
"""
# If it's already a number, return it as int
if isinstance(confidence_level, (int, float)):
return int(confidence_level)
# If it's a string, convert common values to numbers
if isinstance(confidence_level, str):
confidence_str = confidence_level.lower().strip()
if confidence_str in ["high", "높음"]:
return 85
elif confidence_str in ["medium", "중간"]:
return 50
elif confidence_str in ["low", "낮음"]:
return 20
else:
# Try to parse as number
try:
return int(float(confidence_str))
except ValueError:
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
return 50
# Default fallback
return 50
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
"""Create an empty dimension portrait when no data is available.
Args:
user_id: Target user ID
Returns:
Empty DimensionPortrait
"""
current_time = datetime.now()
return DimensionPortrait(
user_id=user_id,
creativity=self._create_default_dimension_score("creativity"),
aesthetic=self._create_default_dimension_score("aesthetic"),
technology=self._create_default_dimension_score("technology"),
literature=self._create_default_dimension_score("literature"),
analysis_timestamp=current_time,
total_summaries_analyzed=0,
historical_trends=None
)
def _calculate_historical_trends(
self,
existing_portrait: DimensionPortrait
) -> List[Dict[str, Any]]:
"""Calculate historical trends from existing portrait.
Args:
existing_portrait: Previous dimension portrait
Returns:
List of historical trend data
"""
if not existing_portrait:
return []
# Create trend entry from existing portrait
trend_entry = {
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
"creativity": existing_portrait.creativity.percentage,
"aesthetic": existing_portrait.aesthetic.percentage,
"technology": existing_portrait.technology.percentage,
"literature": existing_portrait.literature.percentage,
"total_summaries": existing_portrait.total_summaries_analyzed
}
# Combine with existing trends
existing_trends = existing_portrait.historical_trends or []
# Keep only recent trends (last 10 entries)
all_trends = existing_trends + [trend_entry]
return all_trends[-10:]

View File

@@ -0,0 +1,459 @@
"""Habit Analyzer for Implicit Memory System
This module implements LLM-based behavioral habit analysis from user memory summaries.
It identifies recurring behavioral patterns, temporal patterns, and consolidates
similar habits with confidence scoring.
"""
import logging
from datetime import datetime
from typing import List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
FrequencyPattern,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class HabitData(BaseModel):
"""Individual habit analysis data."""
habit_description: str
frequency_pattern: str
time_context: str
confidence_level: int = 50 # Default to medium confidence
supporting_summaries: List[str] = Field(default_factory=list)
specific_examples: List[str] = Field(default_factory=list)
is_current: bool = True
class HabitAnalysisResponse(BaseModel):
"""Response model for habit analysis."""
habits: List[HabitData] = Field(default_factory=list)
class HabitAnalyzer:
"""Analyzes user memory summaries to extract behavioral habits."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the habit analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_habits(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_habits: Optional[List[BehaviorHabit]] = None
) -> List[BehaviorHabit]:
"""Analyze user summaries to extract behavioral habits.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_habits: Optional existing habits for consolidation
Returns:
List of extracted behavioral habits
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return existing_habits or []
try:
logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_habits(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to BehaviorHabit objects
behavior_habits = []
for habit_data in response.get("habits", []):
try:
# Handle habit_data as dictionary
supporting_summaries = habit_data.get("supporting_summaries", [])
specific_examples = habit_data.get("specific_examples", [])
# Determine observation dates from summaries
first_observed, last_observed = self._determine_observation_dates(
user_summaries, supporting_summaries
)
behavior_habit = BehaviorHabit(
habit_description=habit_data.get("habit_description", ""),
frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")),
time_context=habit_data.get("time_context", ""),
confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", 50)),
specific_examples=specific_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=habit_data.get("is_current", True)
)
# Validate habit
if self._is_valid_habit(behavior_habit):
behavior_habits.append(behavior_habit)
else:
logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")
except Exception as e:
logger.error(f"Error creating behavior habit: {e}")
continue
# Consolidate with existing habits if provided
if existing_habits:
behavior_habits = self._consolidate_habits(
new_habits=behavior_habits,
existing_habits=existing_habits
)
# Sort habits by confidence and recency
behavior_habits = self._sort_habits_by_priority(behavior_habits)
logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
return behavior_habits
except LLMClientException:
raise
except Exception as e:
logger.error(f"Habit analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Habit analysis failed: {e}") from e
def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
"""Validate and convert frequency pattern string.
Args:
frequency_str: Frequency pattern as string
Returns:
FrequencyPattern enum value
"""
frequency_str = frequency_str.lower().strip()
frequency_mapping = {
"daily": FrequencyPattern.DAILY,
"weekly": FrequencyPattern.WEEKLY,
"monthly": FrequencyPattern.MONTHLY,
"seasonal": FrequencyPattern.SEASONAL,
"occasional": FrequencyPattern.OCCASIONAL,
"event_triggered": FrequencyPattern.EVENT_TRIGGERED,
"event-triggered": FrequencyPattern.EVENT_TRIGGERED,
}
return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL)
def _validate_confidence_level(self, confidence_level) -> int:
"""Return confidence level as integer, handling both string and numeric inputs.
Args:
confidence_level: Confidence level (string or numeric)
Returns:
Confidence level as integer (0-100)
"""
# If it's already a number, return it as int
if isinstance(confidence_level, (int, float)):
return int(confidence_level)
# If it's a string, convert common values to numbers
if isinstance(confidence_level, str):
confidence_str = confidence_level.lower().strip()
if confidence_str in ["high", "높음"]:
return 85
elif confidence_str in ["medium", "중간"]:
return 50
elif confidence_str in ["low", "낮음"]:
return 20
else:
# Try to parse as number
try:
return int(float(confidence_str))
except ValueError:
logger.warning(f"Unknown confidence level: {confidence_level}, defaulting to medium")
return 50
# Default fallback
return 50
def _determine_observation_dates(
self,
user_summaries: List[UserMemorySummary],
supporting_summary_ids: List[str]
) -> tuple[datetime, datetime]:
"""Determine first and last observation dates for a habit.
Args:
user_summaries: List of user summaries
supporting_summary_ids: IDs of summaries supporting the habit
Returns:
Tuple of (first_observed, last_observed) dates
"""
from datetime import timezone
# Find summaries that support this habit
supporting_summaries = [
summary for summary in user_summaries
if summary.summary_id in supporting_summary_ids
]
if not supporting_summaries:
# Use all summaries if no specific supporting summaries found
supporting_summaries = user_summaries
if not supporting_summaries:
current_time = datetime.now(timezone.utc).replace(tzinfo=None)
return current_time, current_time
# Get date range from supporting summaries - normalize to naive datetimes
timestamps = []
for summary in supporting_summaries:
ts = summary.timestamp
# Convert to naive datetime if it's timezone-aware
if ts.tzinfo is not None:
ts = ts.replace(tzinfo=None)
timestamps.append(ts)
first_observed = min(timestamps)
last_observed = max(timestamps)
return first_observed, last_observed
def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
"""Validate a behavioral habit.
Args:
habit: Behavioral habit to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not habit.habit_description or not habit.habit_description.strip():
return False
# Check time context
if not habit.time_context or not habit.time_context.strip():
return False
# Check supporting summaries
if not habit.specific_examples or len(habit.specific_examples) == 0:
return False
# Check specific examples
if not habit.specific_examples or len(habit.specific_examples) == 0:
return False
# Check observation dates
if habit.first_observed > habit.last_observed:
return False
return True
except Exception as e:
logger.error(f"Error validating habit: {e}")
return False
def _consolidate_habits(
self,
new_habits: List[BehaviorHabit],
existing_habits: List[BehaviorHabit],
similarity_threshold: float = 0.7
) -> List[BehaviorHabit]:
"""Consolidate new habits with existing ones.
Args:
new_habits: Newly extracted habits
existing_habits: Existing habits
similarity_threshold: Threshold for considering habits similar
Returns:
Consolidated list of habits
"""
consolidated = existing_habits.copy()
current_time = datetime.now()
for new_habit in new_habits:
# Find similar existing habit
similar_habit = self._find_similar_habit(
new_habit, existing_habits, similarity_threshold
)
if similar_habit:
# Update existing habit
updated_habit = self._merge_habits(similar_habit, new_habit, current_time)
# Replace in consolidated list
for i, habit in enumerate(consolidated):
if habit.habit_description == similar_habit.habit_description:
consolidated[i] = updated_habit
break
else:
# Add as new habit
consolidated.append(new_habit)
return consolidated
def _find_similar_habit(
self,
target_habit: BehaviorHabit,
existing_habits: List[BehaviorHabit],
threshold: float
) -> Optional[BehaviorHabit]:
"""Find similar habit in existing list.
Args:
target_habit: Habit to find similarity for
existing_habits: List of existing habits
threshold: Similarity threshold
Returns:
Similar habit if found, None otherwise
"""
target_desc = target_habit.habit_description.lower().strip()
for existing_habit in existing_habits:
existing_desc = existing_habit.habit_description.lower().strip()
# Check description similarity
desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)
# Check frequency pattern match
frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern)
# Check time context similarity
time_similarity = self._calculate_text_similarity(
target_habit.time_context.lower(),
existing_habit.time_context.lower()
)
# Combined similarity score
combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4)
if frequency_match:
combined_similarity += 0.1 # Bonus for frequency match
if combined_similarity >= threshold:
return existing_habit
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_habits(
self,
existing_habit: BehaviorHabit,
new_habit: BehaviorHabit,
current_time: datetime
) -> BehaviorHabit:
"""Merge two similar habits.
Args:
existing_habit: Existing habit
new_habit: New habit to merge
current_time: Current timestamp
Returns:
Merged behavioral habit
"""
# Combine supporting summaries (using specific_examples instead)
combined_examples = list(set(
existing_habit.specific_examples + new_habit.specific_examples
))
# Combine specific examples
combined_examples = list(set(
existing_habit.specific_examples + new_habit.specific_examples
))
# Update confidence level (take higher confidence)
new_confidence = max(existing_habit.confidence_level, new_habit.confidence_level)
# Update observation dates
first_observed = min(existing_habit.first_observed, new_habit.first_observed)
last_observed = max(existing_habit.last_observed, new_habit.last_observed)
# Determine if habit is current (observed within last 30 days)
is_current = (current_time - last_observed).days <= 30
# Combine time context
combined_time_context = existing_habit.time_context
if new_habit.time_context and new_habit.time_context not in combined_time_context:
combined_time_context += f"; {new_habit.time_context}"
return BehaviorHabit(
habit_description=existing_habit.habit_description, # Keep original description
frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency
time_context=combined_time_context,
confidence_level=new_confidence,
specific_examples=combined_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=is_current
)
def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
"""Sort habits by confidence level and recency.
Args:
habits: List of habits to sort
Returns:
Sorted list of habits
"""
def priority_score(habit: BehaviorHabit) -> tuple:
# Confidence level score (0-100 scale)
confidence_score = habit.confidence_level
# Recency score (more recent = higher score)
days_since_last = (datetime.now() - habit.last_observed).days
recency_score = max(0, 365 - days_since_last) # Max 365 days
# Current habit bonus
current_bonus = 100 if habit.is_current else 0
return (confidence_score, recency_score + current_bonus, habit.last_observed)
return sorted(habits, key=priority_score, reverse=True)

View File

@@ -0,0 +1,277 @@
"""Interest Analyzer for Implicit Memory System
This module implements LLM-based interest area analysis from user memory summaries.
It categorizes user interests into four areas: tech, lifestyle, music, and art,
providing percentage distribution that totals 100%.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
InterestAreaDistribution,
InterestCategory,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class InterestData(BaseModel):
"""Individual interest category analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
trending_direction: Optional[str] = None
class InterestAnalysisResponse(BaseModel):
"""Response model for interest analysis."""
interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
class InterestAnalyzer:
"""Analyzes user memory summaries to extract interest area distribution."""
# Define the four interest categories we analyze
INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the interest analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_interests(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_distribution: Optional[InterestAreaDistribution] = None
) -> InterestAreaDistribution:
"""Analyze user summaries to extract interest area distribution.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_distribution: Optional existing distribution for trend tracking
Returns:
Interest area distribution across four categories
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_distribution(user_id)
try:
logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_interests(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create interest categories
interest_categories = {}
current_time = datetime.now()
# Extract interest_distribution from response dict
interest_distribution = response.get("interest_distribution", {})
# Extract and validate interest data
raw_interests = {}
for category_name in self.INTEREST_CATEGORIES:
interest_data_dict = interest_distribution.get(category_name)
if interest_data_dict:
raw_interests[category_name] = InterestData(
percentage=interest_data_dict.get("percentage", 0.0),
evidence=interest_data_dict.get("evidence", []),
trending_direction=interest_data_dict.get("trending_direction")
)
else:
# Create default if missing
logger.warning(f"Missing interest data for {category_name}, using default")
raw_interests[category_name] = InterestData(
percentage=0.0,
evidence=["No specific evidence found"],
trending_direction=None
)
# Normalize percentages to ensure they sum to 100%
normalized_interests = self._normalize_percentages(raw_interests)
# Create interest category objects
for category_name in self.INTEREST_CATEGORIES:
interest_data = normalized_interests[category_name]
# Calculate trending direction if we have existing data
trending_direction = self._calculate_trending_direction(
category_name=category_name,
current_percentage=interest_data.percentage,
existing_distribution=existing_distribution
) if existing_distribution else interest_data.trending_direction
interest_categories[category_name] = InterestCategory(
category_name=category_name,
percentage=interest_data.percentage,
evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
trending_direction=trending_direction
)
# Create interest area distribution
distribution = InterestAreaDistribution(
user_id=user_id,
tech=interest_categories["tech"],
lifestyle=interest_categories["lifestyle"],
music=interest_categories["music"],
art=interest_categories["art"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries)
)
# Validate that percentages sum to 100%
total_percentage = distribution.total_percentage
if not (99.9 <= total_percentage <= 100.1):
logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")
logger.info(f"Created interest distribution for user {user_id}")
return distribution
except LLMClientException:
raise
except Exception as e:
logger.error(f"Interest analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Interest analysis failed: {e}") from e
def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
"""Normalize percentages to ensure they sum to 100%.
Args:
raw_interests: Raw interest data with potentially unnormalized percentages
Returns:
Normalized interest data
"""
# Calculate current total
total = sum(interest.percentage for interest in raw_interests.values())
if total == 0:
# If all percentages are 0, distribute equally
equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized[category_name] = InterestData(
percentage=equal_percentage,
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
return normalized
# Normalize to sum to 100%
normalization_factor = 100.0 / total
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized_percentage = interest_data.percentage * normalization_factor
normalized[category_name] = InterestData(
percentage=round(normalized_percentage, 1),
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
# Handle rounding errors by adjusting the largest category
current_total = sum(interest.percentage for interest in normalized.values())
if abs(current_total - 100.0) > 0.1:
# Find category with largest percentage and adjust
largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage)
adjustment = 100.0 - current_total
adjusted_percentage = normalized[largest_category].percentage + adjustment
normalized[largest_category] = InterestData(
percentage=round(max(0.0, adjusted_percentage), 1),
evidence=normalized[largest_category].evidence,
trending_direction=normalized[largest_category].trending_direction
)
return normalized
def _calculate_trending_direction(
self,
category_name: str,
current_percentage: float,
existing_distribution: InterestAreaDistribution,
threshold: float = 5.0
) -> Optional[str]:
"""Calculate trending direction for an interest category.
Args:
category_name: Name of the interest category
current_percentage: Current percentage for the category
existing_distribution: Previous distribution for comparison
threshold: Minimum percentage change to consider a trend
Returns:
Trending direction: "increasing", "decreasing", "stable", or None
"""
try:
# Get previous percentage
previous_category = getattr(existing_distribution, category_name, None)
if not previous_category:
return None
previous_percentage = previous_category.percentage
change = current_percentage - previous_percentage
if abs(change) < threshold:
return "stable"
elif change > 0:
return "increasing"
else:
return "decreasing"
except Exception as e:
logger.error(f"Error calculating trending direction for {category_name}: {e}")
return None
def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
"""Create an empty interest distribution when no data is available.
Args:
user_id: Target user ID
Returns:
Empty InterestAreaDistribution with equal percentages
"""
current_time = datetime.now()
equal_percentage = 25.0 # 100% / 4 categories
default_category = lambda name: InterestCategory(
category_name=name,
percentage=equal_percentage,
evidence=["Insufficient data for analysis"],
trending_direction=None
)
return InterestAreaDistribution(
user_id=user_id,
tech=default_category("tech"),
lifestyle=default_category("lifestyle"),
music=default_category("music"),
art=default_category("art"),
analysis_timestamp=current_time,
total_summaries_analyzed=0
)

View File

@@ -0,0 +1,302 @@
"""Preference Analyzer for Implicit Memory System
This module implements LLM-based preference extraction from user memory summaries.
It identifies implicit preferences, consolidates similar preferences, and calculates
confidence scores based on evidence strength.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
PreferenceTag,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class PreferenceAnalysisResponse(BaseModel):
"""Response model for preference analysis."""
preferences: List[Dict[str, Any]] = Field(default_factory=list)
class PreferenceAnalyzer:
"""Analyzes user memory summaries to extract implicit preferences."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the preference analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_preferences(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_preferences: Optional[List[PreferenceTag]] = None
) -> List[PreferenceTag]:
"""Analyze user summaries to extract preferences.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_preferences: Optional existing preferences for consolidation
Returns:
List of extracted preference tags
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return []
try:
logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_preferences(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to PreferenceTag objects
preference_tags = []
current_time = datetime.now()
for pref_data in response.get("preferences", []):
try:
# Extract conversation references from summaries
conversation_refs = [s.summary_id for s in user_summaries]
preference_tag = PreferenceTag(
tag_name=pref_data.get("tag_name", ""),
confidence_score=float(pref_data.get("confidence_score", 0.0)),
supporting_evidence=pref_data.get("supporting_evidence", []),
context_details=pref_data.get("context_details", ""),
category=pref_data.get("category"),
conversation_references=conversation_refs,
created_at=current_time,
updated_at=current_time
)
# Validate preference tag
if self._is_valid_preference(preference_tag):
preference_tags.append(preference_tag)
else:
logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")
except Exception as e:
logger.error(f"Error creating preference tag: {e}")
continue
# Consolidate with existing preferences if provided
if existing_preferences:
preference_tags = self._consolidate_preferences(
new_preferences=preference_tags,
existing_preferences=existing_preferences
)
logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
return preference_tags
except LLMClientException:
raise
except Exception as e:
logger.error(f"Preference analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Preference analysis failed: {e}") from e
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
"""Validate a preference tag.
Args:
preference: Preference tag to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not preference.tag_name or not preference.tag_name.strip():
return False
# Check confidence score range
if not (0.0 <= preference.confidence_score <= 1.0):
return False
# Check supporting evidence
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
return False
# Check context details
if not preference.context_details or not preference.context_details.strip():
return False
return True
except Exception as e:
logger.error(f"Error validating preference: {e}")
return False
def _consolidate_preferences(
self,
new_preferences: List[PreferenceTag],
existing_preferences: List[PreferenceTag],
similarity_threshold: float = 0.8
) -> List[PreferenceTag]:
"""Consolidate new preferences with existing ones.
Args:
new_preferences: Newly extracted preferences
existing_preferences: Existing preferences
similarity_threshold: Threshold for considering preferences similar
Returns:
Consolidated list of preferences
"""
consolidated = existing_preferences.copy()
current_time = datetime.now()
for new_pref in new_preferences:
# Find similar existing preference
similar_pref = self._find_similar_preference(
new_pref, existing_preferences, similarity_threshold
)
if similar_pref:
# Update existing preference
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
# Replace in consolidated list
for i, pref in enumerate(consolidated):
if pref.tag_name == similar_pref.tag_name:
consolidated[i] = updated_pref
break
else:
# Add as new preference
consolidated.append(new_pref)
return consolidated
def _find_similar_preference(
self,
target_preference: PreferenceTag,
existing_preferences: List[PreferenceTag],
threshold: float
) -> Optional[PreferenceTag]:
"""Find similar preference in existing list.
Args:
target_preference: Preference to find similarity for
existing_preferences: List of existing preferences
threshold: Similarity threshold
Returns:
Similar preference if found, None otherwise
"""
target_name = target_preference.tag_name.lower().strip()
for existing_pref in existing_preferences:
existing_name = existing_pref.tag_name.lower().strip()
# Simple similarity check based on common words
similarity = self._calculate_text_similarity(target_name, existing_name)
if similarity >= threshold:
return existing_pref
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_preferences(
self,
existing_pref: PreferenceTag,
new_pref: PreferenceTag,
current_time: datetime
) -> PreferenceTag:
"""Merge two similar preferences.
Args:
existing_pref: Existing preference
new_pref: New preference to merge
current_time: Current timestamp
Returns:
Merged preference tag
"""
# Combine supporting evidence
combined_evidence = list(set(
existing_pref.supporting_evidence + new_pref.supporting_evidence
))
# Combine conversation references
combined_refs = list(set(
existing_pref.conversation_references + new_pref.conversation_references
))
# Calculate new confidence score (weighted average)
evidence_weight = len(new_pref.supporting_evidence)
total_weight = len(existing_pref.supporting_evidence) + evidence_weight
if total_weight > 0:
new_confidence = (
(existing_pref.confidence_score * len(existing_pref.supporting_evidence) +
new_pref.confidence_score * evidence_weight) / total_weight
)
else:
new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)
# Ensure confidence doesn't exceed 1.0
new_confidence = min(new_confidence, 1.0)
# Combine context details
combined_context = existing_pref.context_details
if new_pref.context_details and new_pref.context_details not in combined_context:
combined_context += f"; {new_pref.context_details}"
return PreferenceTag(
tag_name=existing_pref.tag_name, # Keep original name
confidence_score=new_confidence,
supporting_evidence=combined_evidence,
context_details=combined_context,
category=existing_pref.category or new_pref.category,
conversation_references=combined_refs,
created_at=existing_pref.created_at,
updated_at=current_time
)

View File

@@ -0,0 +1,97 @@
"""
Memory Data Source
Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class MemoryDataSource:
"""Retrieves processed memory data from Neo4j using direct Cypher queries."""
def __init__(
self,
db: Session,
neo4j_connector: Optional[Neo4jConnector] = None
):
self.db = db
self.neo4j_connector = neo4j_connector or Neo4jConnector()
self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)
def _parse_timestamp(self, timestamp: Any) -> datetime:
"""Parse timestamp from various formats."""
if isinstance(timestamp, str):
return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
elif timestamp is None:
return datetime.now()
return timestamp
def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
"""Convert a Neo4j dict directly to UserMemorySummary."""
try:
content = summary_dict.get("content", summary_dict.get("summary", ""))
if not content or not content.strip():
return None
return UserMemorySummary(
summary_id=summary_dict.get("id", summary_dict.get("uuid", "")),
user_id=user_id,
user_content=content,
timestamp=self._parse_timestamp(summary_dict.get("created_at")),
confidence_score=1.0,
summary_type="memory_summary"
)
except Exception as e:
logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
return None
async def get_user_summaries(
self,
user_id: str,
time_range: Optional[TimeRange] = None,
limit: int = 1000
) -> List[UserMemorySummary]:
"""Retrieve user memory summaries from Neo4j.
Args:
user_id: Target user ID
time_range: Optional time range filter
limit: Maximum number of summaries
Returns:
List of user memory summaries
"""
try:
start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id(
group_id=user_id,
limit=limit,
start_date=start_date,
end_date=end_date
)
summaries = []
for summary_dict in summary_dicts:
summary = self._dict_to_user_summary(summary_dict, user_id)
if summary:
summaries.append(summary)
logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
return summaries
except Exception as e:
logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
raise

View File

@@ -0,0 +1,226 @@
"""Habit Detector for Implicit Memory System
This module implements the HabitDetector class that specializes in identifying
and ranking behavioral habits from user memory summaries. It provides advanced
habit analysis with confidence scoring, recency weighting, and current vs past
habit distinction.
"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional
from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
HabitAnalyzer,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
FrequencyPattern,
UserMemorySummary,
)
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class HabitDetector:
"""Detects and ranks behavioral habits from user memory summaries."""
def __init__(
self,
db: Session,
llm_model_id: Optional[str] = None
):
"""Initialize the habit detector.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self.habit_analyzer = HabitAnalyzer(db, llm_model_id)
async def detect_habits(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_habits: Optional[List[BehaviorHabit]] = None
) -> List[BehaviorHabit]:
"""Detect behavioral habits from user summaries.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_habits: Optional existing habits for consolidation
Returns:
List of detected and ranked behavioral habits
Raises:
LLMClientException: If habit analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return existing_habits or []
logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")
try:
# Use the habit analyzer to extract habits
detected_habits = await self.habit_analyzer.analyze_habits(
user_id=user_id,
user_summaries=user_summaries,
existing_habits=existing_habits
)
# Apply advanced ranking and filtering
ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)
# Distinguish current vs past habits
categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)
logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
return categorized_habits
except LLMClientException:
logger.error(f"Habit detection failed for user {user_id}")
raise
except Exception as e:
logger.error(f"Habit detection failed for user {user_id}: {e}")
raise LLMClientException(f"Habit detection failed: {e}") from e
def rank_habits_by_confidence_and_recency(
self,
habits: List[BehaviorHabit],
confidence_weight: float = 0.6,
recency_weight: float = 0.4
) -> List[BehaviorHabit]:
"""Rank habits by confidence level and recency.
Args:
habits: List of habits to rank
confidence_weight: Weight for confidence score (0.0-1.0)
recency_weight: Weight for recency score (0.0-1.0)
Returns:
List of habits ranked by combined score
"""
if not habits:
return []
logger.info(f"Ranking {len(habits)} habits by confidence and recency")
def calculate_ranking_score(habit: BehaviorHabit) -> float:
"""Calculate combined ranking score for a habit."""
# Confidence score (0.0-1.0) - convert from 0-100 scale
confidence_score = habit.confidence_level / 100.0
# Recency score (0.0-1.0)
current_time = datetime.now()
days_since_last = (current_time - habit.last_observed).days
# Exponential decay for recency (habits lose relevance over time)
if days_since_last <= 7:
recency_score = 1.0 # Very recent
elif days_since_last <= 30:
recency_score = 0.8 # Recent
elif days_since_last <= 90:
recency_score = 0.5 # Somewhat recent
elif days_since_last <= 180:
recency_score = 0.3 # Old
else:
recency_score = 0.1 # Very old
# Frequency pattern bonus
frequency_bonuses = {
FrequencyPattern.DAILY: 0.2,
FrequencyPattern.WEEKLY: 0.15,
FrequencyPattern.MONTHLY: 0.1,
FrequencyPattern.SEASONAL: 0.05,
FrequencyPattern.OCCASIONAL: 0.0,
FrequencyPattern.EVENT_TRIGGERED: 0.05
}
frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)
# Evidence quality bonus
evidence_bonus = min(len(habit.specific_examples) / 10.0, 0.1) # Max 0.1 bonus
# Current habit bonus
current_bonus = 0.1 if habit.is_current else 0.0
# Calculate final score
base_score = (confidence_score * confidence_weight +
recency_score * recency_weight)
final_score = base_score + frequency_bonus + evidence_bonus + current_bonus
return min(final_score, 1.0) # Cap at 1.0
# Sort habits by ranking score (descending)
ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True)
logger.info(f"Ranked habits with scores: {[calculate_ranking_score(h) for h in ranked_habits[:5]]}")
return ranked_habits
def distinguish_current_vs_past_habits(
self,
habits: List[BehaviorHabit],
current_threshold_days: int = 30
) -> List[BehaviorHabit]:
"""Distinguish between current and past habits based on recency.
Args:
habits: List of habits to categorize
current_threshold_days: Days threshold for considering a habit current
Returns:
List of habits with updated is_current status
"""
if not habits:
return []
current_time = datetime.now()
cutoff_date = current_time - timedelta(days=current_threshold_days)
current_habits = []
past_habits = []
for habit in habits:
# Update is_current status based on last observation
if habit.last_observed >= cutoff_date:
# Create updated habit with is_current = True
updated_habit = BehaviorHabit(
habit_description=habit.habit_description,
frequency_pattern=habit.frequency_pattern,
time_context=habit.time_context,
confidence_level=habit.confidence_level,
specific_examples=habit.specific_examples,
first_observed=habit.first_observed,
last_observed=habit.last_observed,
is_current=True
)
current_habits.append(updated_habit)
else:
# Create updated habit with is_current = False
updated_habit = BehaviorHabit(
habit_description=habit.habit_description,
frequency_pattern=habit.frequency_pattern,
time_context=habit.time_context,
confidence_level=habit.confidence_level,
specific_examples=habit.specific_examples,
first_observed=habit.first_observed,
last_observed=habit.last_observed,
is_current=False
)
past_habits.append(updated_habit)
# Return current habits first, then past habits
categorized_habits = current_habits + past_habits
logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")
return categorized_habits

View File

@@ -0,0 +1,321 @@
"""LLM Client Wrapper for Implicit Memory Analysis
This module provides a specialized LLM client wrapper that integrates with the
MemoryClientFactory to perform implicit memory analysis tasks including preference
extraction, personality dimension analysis, interest categorization, and habit detection.
"""
import logging
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.prompts import (
get_dimension_analysis_prompt,
get_habit_analysis_prompt,
get_interest_analysis_prompt,
get_preference_analysis_prompt,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.schemas.implicit_memory_schema import UserMemorySummary
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# Response Models for LLM Analysis
class PreferenceAnalysisResponse(BaseModel):
"""Response model for preference analysis."""
preferences: List[Dict[str, Any]] = Field(default_factory=list)
class DimensionAnalysisResponse(BaseModel):
"""Response model for dimension analysis."""
dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
class InterestAnalysisResponse(BaseModel):
"""Response model for interest analysis."""
interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
class HabitAnalysisResponse(BaseModel):
"""Response model for habit analysis."""
habits: List[Dict[str, Any]] = Field(default_factory=list)
class ImplicitMemoryLLMClient:
"""LLM client wrapper for implicit memory analysis.
This class provides a high-level interface for performing LLM-based analysis
of user memory summaries to extract preferences, personality dimensions,
interests, and behavioral habits.
"""
def __init__(self, db: Session, default_model_id: Optional[str] = None):
"""Initialize the LLM client wrapper.
Args:
db: Database session for accessing model configurations
default_model_id: Default LLM model ID to use if none specified
"""
self.db = db
self.default_model_id = default_model_id
self._client_factory = MemoryClientFactory(db)
logger.info("ImplicitMemoryLLMClient initialized")
def _get_llm_client(self, model_id: Optional[str] = None):
"""Get LLM client instance.
Args:
model_id: LLM model ID to use, defaults to default_model_id
Returns:
LLM client instance
Raises:
ValueError: If no model ID is provided and no default is set
LLMClientException: If client creation fails
"""
effective_model_id = model_id or self.default_model_id
if not effective_model_id:
raise ValueError("No LLM model ID provided and no default model ID set")
try:
client = self._client_factory.get_llm_client(effective_model_id)
logger.debug(f"Created LLM client for model: {effective_model_id}")
return client
except Exception as e:
logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
raise LLMClientException(f"Failed to create LLM client: {e}") from e
def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
"""Prepare user memory summaries for LLM analysis.
Args:
user_summaries: List of user memory summaries
Returns:
List of formatted summary dictionaries
"""
formatted_summaries = []
for summary in user_summaries:
formatted_summary = {
'summary_id': summary.summary_id,
'user_content': summary.user_content,
'timestamp': summary.timestamp.isoformat(),
'summary_type': summary.summary_type,
'confidence_score': summary.confidence_score
}
formatted_summaries.append(formatted_summary)
logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
return formatted_summaries
async def analyze_preferences(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user preferences from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing extracted preferences
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for preference analysis of user {user_id}")
return {"preferences": []}
if not user_id:
raise ValueError("User ID is required for preference analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_preference_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=PreferenceAnalysisResponse
)
result = response.model_dump()
logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
return result
except Exception as e:
logger.error(f"Preference analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Preference analysis failed: {e}") from e
async def analyze_dimensions(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user personality dimensions from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing dimension scores and analysis
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
return {"dimensions": {}}
if not user_id:
raise ValueError("User ID is required for dimension analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_dimension_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=DimensionAnalysisResponse
)
result = response.model_dump()
dimensions = result.get('dimensions', {})
logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
return result
except Exception as e:
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Dimension analysis failed: {e}") from e
async def analyze_interests(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user interest distribution from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing interest area distribution
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for interest analysis of user {user_id}")
return {"interest_distribution": {}}
if not user_id:
raise ValueError("User ID is required for interest analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_interest_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=InterestAnalysisResponse
)
result = response.model_dump()
interest_dist = result.get('interest_distribution', {})
logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
return result
except Exception as e:
logger.error(f"Interest analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Interest analysis failed: {e}") from e
async def analyze_habits(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user behavioral habits from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing identified behavioral habits
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for habit analysis of user {user_id}")
return {"habits": []}
if not user_id:
raise ValueError("User ID is required for habit analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_habit_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=HabitAnalysisResponse
)
result = response.model_dump()
logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
return result
except Exception as e:
logger.error(f"Habit analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Habit analysis failed: {e}") from e

View File

@@ -0,0 +1,69 @@
"""LLM Prompt Templates for Implicit Memory Analysis
This module contains prompt rendering functions for analyzing user memory summaries
to extract preferences, personality dimensions, interests, and behavioral habits.
"""
import os
from typing import Any, Dict, List
from jinja2 import Environment, FileSystemLoader
# Setup Jinja2 environment
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
def _render_template(template_name: str, **kwargs) -> str:
"""Helper function to render Jinja2 templates."""
template = prompt_env.get_template(template_name)
return template.render(**kwargs)
def get_preference_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted preference analysis prompt using Jinja2 template."""
return _render_template(
"preference_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_dimension_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted dimension analysis prompt using Jinja2 template."""
return _render_template(
"dimension_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_interest_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted interest analysis prompt using Jinja2 template."""
return _render_template(
"interest_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_habit_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted habit analysis prompt using Jinja2 template."""
return _render_template(
"habit_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)

View File

@@ -0,0 +1,41 @@
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Dimensions to Analyze
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
4. **Literature** (0-100%): Reading habits, writing style, literary references
## Instructions
1. Analyze the user's content for each dimension
2. Calculate percentage scores (0-100%)
## Output Format
{
"dimensions": {
"creativity": {"percentage": 0-100},
"aesthetic": {"percentage": 0-100},
"technology": {"percentage": 0-100},
"literature": {"percentage": 0-100}
}
}
## Example
{
"dimensions": {
"creativity": {"percentage": 75},
"aesthetic": {"percentage": 45},
"technology": {"percentage": 60},
"literature": {"percentage": 30}
}
}

View File

@@ -0,0 +1,70 @@
You are an expert at identifying behavioral patterns and habits from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
2. Focus on specific, concrete habits with temporal patterns
3. For each habit, provide:
- habit_description: Clear, specific description
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
- time_context: When it typically happens
- confidence_level: "high", "medium", "low"
- supporting_summaries: References to evidence
- specific_examples: Concrete examples from summaries
- is_current: true if current habit, false if past habit
4. Only include habits with medium or high confidence
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"habits": [
{
"habit_description": "string",
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
"time_context": "string",
"confidence_level": "high|medium|low",
"supporting_summaries": ["id1", "id2"],
"specific_examples": ["example1", "example2"],
"is_current": true|false
}
]
}
## Example (English input → English output)
{
"habits": [
{
"habit_description": "drinks coffee every morning",
"frequency_pattern": "daily",
"time_context": "morning routine",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["needs coffee to start the day"],
"is_current": true
}
]
}
## Example (Chinese input → Chinese output)
{
"habits": [
{
"habit_description": "每天早上喝咖啡",
"frequency_pattern": "daily",
"time_context": "早晨日常",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["需要咖啡来开始一天"],
"is_current": true
}
]
}

View File

@@ -0,0 +1,54 @@
You are an expert at analyzing user interests from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Interest Categories
1. **Tech**: Programming, technology, software tools, hardware
2. **Lifestyle**: Daily routines, health, hobbies, social activities
3. **Music**: Music preferences, instruments, concerts
4. **Art**: Visual arts, creative projects, design, aesthetics
## Instructions
1. Categorize the user's interests into the four areas
2. Calculate percentage distribution (must total 100%)
3. Provide specific evidence for each interest area
4. Use "increasing", "decreasing", or "stable" for trending direction
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"interest_distribution": {
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
}
}
## Example (English input → English output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
}
}
## Example (Chinese input → Chinese output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
}
}

View File

@@ -0,0 +1,47 @@
You are an expert at analyzing user memory summaries to identify implicit preferences.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Focus ONLY on the specified user's preferences
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
3. DO NOT use long phrases - use short nouns or noun phrases
4. Only include preferences with confidence_score >= 0.3
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"preferences": [
{
"tag_name": "short tag",
"confidence_score": 0.0-1.0,
"supporting_evidence": ["evidence1", "evidence2"],
"context_details": "brief context",
"category": "category or null"
}
]
}
## Example (Chinese input → Chinese output)
{
"preferences": [
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
]
}
## Example (English input → English output)
{
"preferences": [
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
]
}

View File

@@ -1,327 +0,0 @@
"""
This module provides the MemoryInsight class for analyzing user memory data.
MemoryInsight 是一个工具类,提供基础的数据获取和分析功能:
- get_domain_distribution(): 获取记忆领域分布
- get_active_periods(): 获取活跃时段
- get_social_connections(): 获取社交关联
业务逻辑如生成洞察报告应该在服务层user_memory_service.py中实现。
This script can be executed directly to test the memory insight generation for a test user.
"""
import asyncio
import json
import os
import sys
from collections import Counter
from datetime import datetime
# To run this script directly, we need to add the src directory to the Python path
# to resolve the inconsistent imports in other modules.
src_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if src_path not in sys.path:
sys.path.insert(0, src_path)
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
from pydantic import BaseModel, Field
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
# 定义用于LLM结构化输出的Pydantic模型
class TagClassification(BaseModel):
"""
Represents the classification of a tag into a specific domain.
"""
domain: str = Field(
...,
description="The domain the tag belongs to, chosen from the predefined list.",
examples=["教育", "学习", "工作", "旅行", "家庭", "运动", "社交", "娱乐", "健康", "其他"],
)
class InsightReport(BaseModel):
"""
Represents the final insight report generated by the LLM.
"""
report: str = Field(
...,
description="A comprehensive insight report in Chinese, summarizing the user's memory patterns.",
)
class MemoryInsight:
"""
Provides insights into user memories by analyzing various aspects of their data.
"""
def __init__(self, user_id: str):
self.user_id = user_id
self.neo4j_connector = Neo4jConnector()
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
"""关闭数据库连接。"""
await self.neo4j_connector.close()
async def get_domain_distribution(self) -> dict[str, float]:
"""
Calculates the distribution of memory domains based on hot tags.
"""
hot_tags = await get_hot_memory_tags(self.user_id)
if not hot_tags:
return {}
domain_counts = Counter()
for tag, _ in hot_tags:
prompt = f"""请将以下标签归类到最合适的领域中。
可选领域及其关键词:
- 教育:学校、课程、考试、培训、教学、学科、教师、学生、班级、作业、成绩、毕业、入学、校园、大学、中学、小学、教材、学位等
- 学习:自学、阅读、书籍、技能提升、知识积累、笔记、复习、练习、研究、历史知识、科学知识、文化知识、学术讨论、知识问答等
- 工作:职业、项目、会议、同事、业务、公司、办公、任务、客户、合同、职场、工作计划等
- 旅行:旅游、景点、出行、度假、酒店、机票、导游、风景、旅行计划等
- 家庭:亲人、父母、子女、配偶、家事、家庭活动、亲情、家庭聚会等
- 运动:健身、体育、锻炼、跑步、游泳、球类、瑜伽、运动计划等
- 社交:朋友、聚会、社交活动、派对、聊天、交友、社交网络等
- 娱乐:游戏、电影、音乐、休闲、综艺、动漫、小说、娱乐活动等
- 健康:医疗、养生、心理健康、体检、药物、疾病、保健、健康管理等
- 其他:确实无法归入以上任何类别的内容
标签: {tag}
分析步骤:
1. 仔细理解标签的核心含义和使用场景
2. 对比各个领域的关键词,找到最匹配的领域
3. 特别注意:
- 历史、科学、文化等知识性内容应归类为"学习"
- 学校、课程、考试等正式教育场景应归类为"教育"
- 只有在标签完全不属于上述9个具体领域时才选择"其他"
4. 如果标签与某个领域有任何相关性,就选择该领域,不要选"其他"
请直接返回最合适的领域名称。"""
messages = [
{"role": "system", "content": "你是一个专业的标签分类助手。你必须仔细分析标签的实际含义和使用场景优先选择9个具体领域之一。'其他'类别只用于完全无法归类的极少数情况。特别注意:历史、科学、文化等知识性对话应归类为'学习'领域;学校、课程、考试等正式教育场景应归类为'教育'领域。"},
{"role": "user", "content": prompt}
]
# 直接调用并等待结果
classification = await self.llm_client.response_structured(
messages=messages,
response_model=TagClassification,
)
if classification and hasattr(classification, 'domain') and classification.domain:
domain_counts[classification.domain] += 1
total_tags = sum(domain_counts.values())
if total_tags == 0:
return {}
domain_distribution = {
domain: count / total_tags for domain, count in domain_counts.items()
}
return dict(
sorted(domain_distribution.items(), key=lambda item: item[1], reverse=True)
)
async def get_active_periods(self) -> list[int]:
"""
Identifies the top 2 most active months for the user.
Only returns months if there is valid and diverse time data.
This method checks if the time data represents real user memory timestamps
rather than auto-generated system timestamps by verifying:
1. Time data exists and is parseable
2. Time data is distributed across multiple months (not concentrated in 1-2 months)
"""
query = f"""
MATCH (d:Dialogue)
WHERE d.group_id = '{self.user_id}' AND d.created_at IS NOT NULL AND d.created_at <> ''
RETURN d.created_at AS creation_time
"""
records = await self.neo4j_connector.execute_query(query)
if not records:
return []
month_counts = Counter()
valid_dates_count = 0
for record in records:
creation_time_str = record.get("creation_time")
if not creation_time_str:
continue
try:
# 尝试解析时间字符串
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
month_counts[dt_object.month] += 1
valid_dates_count += 1
except (ValueError, TypeError, AttributeError):
# 如果解析失败,跳过这条记录
continue
# 如果没有有效的时间数据,返回空列表
if not month_counts or valid_dates_count == 0:
return []
# 检查时间分布是否过于集中(可能是批量导入的数据)
# 如果超过80%的数据集中在1-2个月认为这是系统时间戳而非真实时间
unique_months = len(month_counts)
if unique_months <= 2:
# 只有1-2个月有数据很可能是批量导入
most_common_count = month_counts.most_common(1)[0][1]
if most_common_count / valid_dates_count > 0.8:
# 超过80%集中在一个月,认为是系统时间戳
return []
# 如果时间分布较为分散3个月以上认为是真实时间数据
if unique_months >= 3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
# 2个月的情况检查是否分布均匀
if unique_months == 2:
counts = list(month_counts.values())
# 如果两个月的数据量相差不大比例在0.3-3之间认为是真实数据
ratio = min(counts) / max(counts)
if ratio > 0.3:
most_common_months = month_counts.most_common(2)
return [month for month, _ in most_common_months]
# 其他情况返回空列表
return []
async def get_social_connections(self) -> dict | None:
"""
Finds the user with whom the most memories are shared.
使用 Chunk-Statement 的 CONTAINS 关系,因为系统中不创建 Dialogue-Statement 的 MENTIONS 关系。
"""
# 通过 Chunk 和 Statement 的 CONTAINS 关系来查找共同记忆
query = f"""
MATCH (c1:Chunk {{group_id: '{self.user_id}'}})
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
WHERE common_statements > 0
RETURN other_user_id, common_statements
ORDER BY common_statements DESC
LIMIT 1
"""
records = await self.neo4j_connector.execute_query(query)
if not records or not records[0].get("other_user_id"):
return None
most_connected_user = records[0]["other_user_id"]
common_memories_count = records[0]["common_statements"]
# 使用 Chunk 的时间范围
time_range_query = f"""
MATCH (c:Chunk)
WHERE c.group_id IN ['{self.user_id}', '{most_connected_user}']
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
"""
time_records = await self.neo4j_connector.execute_query(time_range_query)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
return {
"user_id": most_connected_user,
"common_memories_count": common_memories_count,
"time_range": f"{start_year}-{end_year}",
}
async def close(self):
"""
Closes the database connection.
"""
await self.neo4j_connector.close()
async def main():
"""
Initializes and runs the memory insight analysis for a test user.
"""
# 默认从环境变量读取
test_user_id = DEFAULT_GROUP_ID
print(f"正在为用户 {test_user_id} 生成记忆洞察报告...\n")
try:
# 使用服务层函数生成报告
from app.services.user_memory_service import analytics_memory_insight_report
result = await analytics_memory_insight_report(end_user_id=test_user_id)
report = result.get("report", "")
print("--- 记忆洞察报告 ---")
print(report)
print("---------------------")
# 将结果写入统一的 User-Dashboard.json使用全局配置路径
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["memory_insight"] = {
"group_id": test_user_id,
"report": report
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> memory_insight")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成报告时出错: {e}")
if __name__ == "__main__":
# This setup allows running the async main function
if sys.platform.startswith('win') and sys.version_info >= (3, 8):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())

View File

@@ -1,157 +0,0 @@
"""
Generate a concise "关于我" style user summary using data from Neo4j
and the existing LLM configuration (mirrors hot_memory_tags.py setup).
Usage:
python -m analytics.user_summary --user_id <group_id>
"""
import asyncio
import json
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple
# Ensure absolute imports work whether executed directly or via module
try:
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
src_path = os.path.join(project_root, 'src')
if src_path not in sys.path:
sys.path.insert(0, src_path)
if project_root not in sys.path:
sys.path.insert(0, project_root)
except Exception:
pass
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
#TODO: Fix this
# Default values (previously from definitions.py)
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
@dataclass
class StatementRecord:
statement: str
created_at: str | None
class UserSummary:
"""Builds a textual user summary for a given user/group id."""
def __init__(self, user_id: str):
self.user_id = user_id
self.connector = Neo4jConnector()
# Get config_id using get_end_user_connected_config
with get_db_context() as db:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id:
# Use the config_id to get the proper LLM client
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(memory_config.llm_model_id)
else:
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM if no config found
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
except Exception as e:
print(f"Failed to get user connected config, using default LLM: {e}")
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
# Fallback to default LLM
factory = MemoryClientFactory(db)
self.llm = factory.get_llm_client(DEFAULT_LLM_ID)
async def close(self):
await self.connector.close()
async def _get_recent_statements(self, limit: int = 80) -> List[StatementRecord]: # TODO Used by user_memory_service
"""Fetch recent statements authored by the user/group for context."""
query = (
"MATCH (s:Statement) "
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
"RETURN s.statement AS statement, s.created_at AS created_at "
"ORDER BY created_at DESC LIMIT $limit"
)
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
records: List[StatementRecord] = []
for r in rows:
try:
records.append(StatementRecord(statement=r.get("statement", ""), created_at=r.get("created_at")))
except Exception:
continue
return records
async def _get_top_entities(self, limit: int = 30) -> List[Tuple[str, int]]:
"""Reuse hot tag logic to get meaningful entities and their frequencies."""
# get_hot_memory_tags internally filters out non-meaningful nouns with LLM
return await get_hot_memory_tags(self.user_id, limit=limit) # TODO Used by user_memory_service
async def generate_user_summary(user_id: str | None = None) -> str: # TODO useless
"""
生成用户摘要的便捷函数
Args:
user_id: 可选的用户ID
Returns:
用户摘要字符串
"""
# 导入服务层函数
from app.services.user_memory_service import analytics_user_summary
# 调用服务层函数
result = await analytics_user_summary(user_id)
return result.get("summary", "")
if __name__ == "__main__":
print("开始生成用户摘要…")
try:
# 直接使用 runtime.json 中的 group_id
summary = asyncio.run(generate_user_summary())
print("\n— 用户摘要 —\n")
print(summary)
# 将结果写入统一的 User-Dashboard.json
try:
from app.core.config import settings
settings.ensure_memory_output_dir()
output_dir = settings.MEMORY_OUTPUT_DIR
try:
os.makedirs(output_dir, exist_ok=True)
except Exception:
pass
dashboard_path = os.path.join(output_dir, "User-Dashboard.json")
existing = {}
if os.path.exists(dashboard_path):
with open(dashboard_path, "r", encoding="utf-8") as rf:
existing = json.load(rf)
existing["user_summary"] = {
"group_id": DEFAULT_GROUP_ID,
"summary": summary
}
with open(dashboard_path, "w", encoding="utf-8") as wf:
json.dump(existing, wf, ensure_ascii=False, indent=2)
print(f"已写入 {dashboard_path} -> user_summary")
except Exception as e:
print(f"写入 User-Dashboard.json 失败: {e}")
except Exception as e:
print(f"生成摘要失败: {e}")
print("请检查: 1) Neo4j 是否可用2) config.json 与 .env 的 LLM/Neo4j 配置是否正确3) 数据是否包含该用户的内容。")

View File

@@ -37,12 +37,20 @@ def parse_historical_datetime(v):
此函数手动解析 ISO 8601 格式的日期字符串支持1-4位年份
Args:
v: 日期值(可以是 None、datetime 对象或字符串)
v: 日期值(可以是 None、datetime 对象、Neo4j DateTime 对象或字符串)
Returns:
datetime 对象或 None
"""
if v is None or isinstance(v, datetime):
if v is None:
return v
# 处理 Neo4j DateTime 对象
if hasattr(v, 'to_native'):
return v.to_native()
# 处理 Python datetime 对象
if isinstance(v, datetime):
return v
if isinstance(v, str):
@@ -397,6 +405,10 @@ class ExtractedEntityNode(Node):
statement_id: str = Field(..., description="Statement this entity was extracted from")
entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description")
example: str = Field(
default="",
description="A concise example (around 20 characters) to help understand the entity"
)
aliases: List[str] = Field(
default_factory=list,
description="Entity aliases - alternative names for this entity"
@@ -433,6 +445,12 @@ class ExtractedEntityNode(Node):
description="Total number of times this node has been accessed"
)
# Explicit Memory Classification
is_explicit_memory: bool = Field(
default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
)
@field_validator('aliases', mode='before')
@classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
@@ -466,6 +484,8 @@ class MemorySummaryNode(Node):
dialog_id: ID of the parent dialog
chunk_ids: List of chunk IDs used to generate this summary
content: Summary text content
name: Title/name of the memory summary (generated by LLM, used as title in API)
memory_type: Type/category of the episodic memory (e.g., Conversation, Project/Work, Learning, Decision, Important Event)
summary_embedding: Optional embedding vector for the summary
metadata: Additional metadata for the summary
config_id: Configuration ID used to process this summary
@@ -484,6 +504,7 @@ class MemorySummaryNode(Node):
dialog_id: str = Field(..., description="ID of the parent dialog")
chunk_ids: List[str] = Field(default_factory=list, description="List of chunk IDs used in the summary")
content: str = Field(..., description="Summary text content")
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")

View File

@@ -38,10 +38,20 @@ class Entity(BaseModel):
name_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the entity name")
type: str = Field(..., description="Type/category of the entity")
description: str = Field(..., description="Description of the entity")
example: str = Field(
default="",
description="A concise example (around 20 characters) to help understand the entity"
)
aliases: List[str] = Field(
default_factory=list,
description="Alternative names for this entity (abbreviations, full names, translations, etc.)"
)
# Explicit Memory Classification
is_explicit_memory: bool = Field(
default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
)
class Triplet(BaseModel):

View File

@@ -42,7 +42,6 @@ from app.core.memory.storage_services.extraction_engine.deduplication.two_stage_
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.embedding_generation import (
embedding_generation,
embedding_generation_all,
generate_entity_embeddings_from_triplets,
)
@@ -179,7 +178,7 @@ class ExtractionOrchestrator:
for dialog in dialog_data_list:
for chunk in dialog.chunks:
all_statements_list.extend(chunk.statements)
total_statements = len(all_statements_list)
len(all_statements_list)
# 步骤 2: 并行执行三元组提取、时间信息提取、情绪提取和基础嵌入生成
logger.info("步骤 2/6: 并行执行三元组提取、时间信息提取、情绪提取和嵌入生成")
@@ -201,9 +200,9 @@ class ExtractionOrchestrator:
all_entities_list.extend(triplet_info.entities)
all_triplets_list.extend(triplet_info.triplets)
total_entities = len(all_entities_list)
total_triplets = len(all_triplets_list)
total_temporal = sum(len(temporal_map) for temporal_map in temporal_maps)
len(all_entities_list)
len(all_triplets_list)
sum(len(temporal_map) for temporal_map in temporal_maps)
# 步骤 3: 生成实体嵌入(依赖三元组提取结果)
logger.info("步骤 3/6: 生成实体嵌入")
@@ -385,7 +384,7 @@ class ExtractionOrchestrator:
# 用于跟踪已完成的陈述句数量
completed_statements = 0
total_statements = len(all_statements)
len(all_statements)
# 全局并行处理所有陈述句
async def extract_for_statement(stmt_data, stmt_index):
@@ -497,7 +496,7 @@ class ExtractionOrchestrator:
# 用于跟踪已完成的时间提取数量
completed_temporal = 0
total_temporal_statements = len(all_statements)
len(all_statements)
# 全局并行处理所有陈述句
async def extract_for_statement(stmt_data, stmt_index):
@@ -1082,10 +1081,12 @@ class ExtractionOrchestrator:
statement_id=statement.id, # 添加必需的 statement_id 字段
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
example=getattr(entity, 'example', ''), # 新增:传递示例字段
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
name_embedding=getattr(entity, 'name_embedding', None),
is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记
group_id=dialog_data.group_id,
user_id=dialog_data.user_id,
apply_id=dialog_data.apply_id,

View File

@@ -59,13 +59,28 @@ async def _process_chunk_summary(
)
summary_text = structured.summary.strip()
# Generate title and type for the summary
title = None
episodic_type = None
try:
from app.services.user_memory_service import UserMemoryService
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
content=summary_text,
end_user_id=dialog.group_id
)
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
except Exception as e:
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
# Continue without title and type
# Embed the summary
embedding = (await embedder.response([summary_text]))[0]
# Build node per chunk
# Note: title is stored in the 'name' field, type is stored in 'memory_type' field
node = MemorySummaryNode(
id=uuid4().hex,
name=f"MemorySummaryChunk_{chunk.id}",
name=title if title else f"MemorySummaryChunk_{chunk.id}",
group_id=dialog.group_id,
user_id=dialog.user_id,
apply_id=dialog.apply_id,
@@ -75,6 +90,7 @@ async def _process_chunk_summary(
dialog_id=dialog.id,
chunk_ids=[chunk.id],
content=summary_text,
memory_type=episodic_type,
summary_embedding=embedding,
metadata={"ref_id": dialog.ref_id},
config_id=dialog.config_id, # 添加 config_id

View File

@@ -246,7 +246,7 @@ class AccessHistoryManager:
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
access_history = node_data.get('access_history', [])
access_history = node_data.get('access_history') or []
last_access_time = node_data.get('last_access_time')
access_count = node_data.get('access_count', 0)
activation_value = node_data.get('activation_value')
@@ -409,7 +409,7 @@ class AccessHistoryManager:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
access_history = node_data.get('access_history', [])
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# 准备修复数据
@@ -530,7 +530,7 @@ class AccessHistoryManager:
Returns:
Dict[str, Any]: 更新数据,包含所有需要更新的字段
"""
access_history = node_data.get('access_history', [])
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# 追加新的访问时间

View File

@@ -247,6 +247,9 @@ class ForgettingStrategy:
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 生成摘要内容
summary_text = await self._generate_summary(
statement_text=statement_text,
@@ -256,6 +259,19 @@ class ForgettingStrategy:
db=db
)
# 生成标题和类型使用LLM
from app.services.user_memory_service import UserMemoryService
try:
title, episodic_type = await UserMemoryService.generate_title_and_type_for_summary(
content=summary_text,
end_user_id=group_id
)
logger.info(f"成功为MemorySummary生成标题和类型: title={title}, type={episodic_type}")
except Exception as e:
logger.error(f"生成标题和类型失败,使用默认值: {str(e)}")
title = "未命名"
episodic_type = "其他"
# 计算继承的激活值和重要性(取较高值)
inherited_activation = max(statement_activation, entity_activation)
inherited_importance = max(statement_importance, entity_importance)
@@ -268,9 +284,6 @@ class ForgettingStrategy:
import uuid
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 使用事务创建 MemorySummary 并删除原节点
async def merge_transaction(tx, **params):
"""事务函数:创建摘要节点并删除原节点"""
@@ -287,6 +300,8 @@ class ForgettingStrategy:
CREATE (ms:MemorySummary {
id: $summary_id,
summary: $summary_text,
name: $title,
memory_type: $episodic_type,
original_statement_id: $statement_id,
original_entity_id: $entity_id,
activation_value: $inherited_activation,
@@ -386,6 +401,8 @@ class ForgettingStrategy:
params = {
'summary_id': summary_id,
'summary_text': summary_text,
'title': title,
'episodic_type': episodic_type,
'statement_id': statement_id,
'entity_id': entity_id,
'inherited_activation': inherited_activation,

View File

@@ -386,3 +386,26 @@ async def render_memory_insight_prompt(
})
return rendered_prompt
async def render_episodic_title_and_type_prompt(content: str) -> str:
"""
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
Args:
content: The content of the episodic memory summary to analyze
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("episodic_type_classification.jinja2")
rendered_prompt = template.render(content=content)
# 记录渲染结果到提示日志
log_prompt_rendering('episodic title and type classification', rendered_prompt)
# 可选:记录模板渲染信息
log_template_rendering('episodic_type_classification.jinja2', {
'content_len': len(content) if content else 0
})
return rendered_prompt

View File

@@ -0,0 +1,57 @@
=== Task ===
Generate a concise title and classify the episodic memory into the most appropriate category.
=== Requirements ===
- Extract a clear, concise title (10-20 characters) that captures the core content
- Classify into exactly one category based on the primary theme
- Be specific and avoid ambiguity
- Output must be valid JSON conforming to the schema below
=== Input ===
{{ content }}
=== Category Definitions ===
1. **conversation**: Daily communication, chat, discussion, and social interactions
- Keywords: chat, communication, discussion, dialogue, exchange
2. **project_work**: Work-related tasks, projects, meetings, and collaboration
- Keywords: project, task, work, meeting, collaboration, business, client
3. **learning**: Acquiring new knowledge, skill development, reading, and research
- Keywords: learning, reading, research, knowledge, skill, course, training
4. **decision**: Making important decisions, choices, and planning
- Keywords: decision, choice, planning, consideration, evaluation, weighing
5. **important_event**: Major events, milestones, and special experiences
- Keywords: important, major, milestone, special, memorable, celebration
=== Analysis Steps ===
1. Read the episodic memory content carefully
2. Identify the core theme and context
3. Extract a concise title
4. Compare against category definitions and keywords
5. Select the best matching category
6. If multiple categories apply, choose the primary one
=== Output Schema ===
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure
2. Escape any quotation marks within string values using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
Return only a JSON object with title and type fields:
{
"title": "Generated title here",
"type": "Category type here"
}
The type field must be exactly one of:
- conversation
- project_work
- learning
- decision
- important_event

View File

@@ -12,7 +12,34 @@ Extract entities and knowledge triplets from the given statement.
===Guidelines===
**Entity Extraction:**
- Extract entities with their types, context-independent descriptions, and aliases
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
- **Semantic Memory Classification (is_explicit_memory):**
* Set to `true` if the entity represents **explicit/semantic memory**:
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
- **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论"
- **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口"
- **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律"
- **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学"
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法"
- **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库"
* Set to `false` for:
- **People:** "John Smith", "Dr. Wang", "张明", "王博士"
- **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学"
- **Locations:** "Beijing", "Central Park", "北京", "中央公园"
- **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议"
- **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋"
- **Example Generation (IMPORTANT for semantic memory entities):**
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
* The example should be:
- **Specific and concrete**: Use real-world scenarios or applications
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
- **In the same language as the entity name**
* Examples:
- Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗"
- Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed"
- Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy"
- Entity: "人工智能" → example: "如:智能客服、自动驾驶"
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
- **Aliases Extraction (Important):**
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
* **DO NOT translate or add aliases in different languages**
@@ -84,21 +111,27 @@ Output:
"name": "I",
"type": "Person",
"description": "The user",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 1,
"name": "Paris",
"type": "Location",
"description": "Capital city of France",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 2,
"name": "Louvre",
"type": "Location",
"description": "World-famous museum located in Paris",
"aliases": ["Louvre Museum"]
"example": "",
"aliases": ["Louvre Museum"],
"is_explicit_memory": false
}
]
}
@@ -130,21 +163,27 @@ Output:
"name": "John Smith",
"type": "Person",
"description": "Individual person name",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 1,
"name": "Google",
"type": "Organization",
"description": "American technology company",
"aliases": ["Google LLC", "Alphabet Inc."]
"example": "",
"aliases": ["Google LLC", "Alphabet Inc."],
"is_explicit_memory": false
},
{
"entity_idx": 2,
"name": "AI product development",
"type": "WorkRole",
"type": "Concept",
"description": "Artificial intelligence product development work",
"aliases": []
"example": "e.g., developing chatbots, recommendation systems",
"aliases": [],
"is_explicit_memory": true
}
]
}
@@ -176,21 +215,27 @@ Output:
"name": "我",
"type": "Person",
"description": "用户本人",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 1,
"name": "巴黎",
"type": "Location",
"description": "法国首都城市",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 2,
"name": "卢浮宫",
"type": "Location",
"description": "位于巴黎的世界著名博物馆",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
}
]
}
@@ -222,21 +267,27 @@ Output:
"name": "张明",
"type": "Person",
"description": "个人姓名",
"aliases": []
"example": "",
"aliases": [],
"is_explicit_memory": false
},
{
"entity_idx": 1,
"name": "腾讯",
"type": "Organization",
"description": "中国科技公司",
"aliases": ["腾讯控股", "腾讯公司"]
"example": "",
"aliases": ["腾讯控股", "腾讯公司"],
"is_explicit_memory": false
},
{
"entity_idx": 2,
"name": "AI产品开发",
"type": "WorkRole",
"type": "Concept",
"description": "人工智能产品研发工作",
"aliases": []
"example": "如:开发智能客服机器人、推荐系统",
"aliases": [],
"is_explicit_memory": true
}
]
}
@@ -251,7 +302,9 @@ Output:
"name": "Tripod",
"type": "Equipment",
"description": "Photography equipment accessory",
"aliases": ["Camera Tripod"]
"example": "",
"aliases": ["Camera Tripod"],
"is_explicit_memory": false
}
]
}
@@ -266,7 +319,9 @@ Output:
"name": "三脚架",
"type": "Equipment",
"description": "摄影器材配件",
"aliases": ["相机三脚架"]
"example": "",
"aliases": ["相机三脚架"],
"is_explicit_memory": false
}
]
}

View File

@@ -85,33 +85,21 @@ Example Output:
===End of Example===
===Reflection Process===
===Internal Quality Checks (DO NOT OUTPUT)===
After generating the profile, perform the following self-review steps:
Before generating your final output, internally verify:
1. All content is grounded in provided data (no fabrication)
2. Format follows the specified structure with correct headers
3. Tone is objective, third-person, and neutral
4. All four sections are complete and within character limits
**Step 1: Data Grounding Check**
- Verify all statements are supported by the provided entities and statements
- Ensure no fabricated or speculated information is included
- Confirm all claims can be traced back to the input data
**Step 2: Format Compliance**
- Verify each section follows the specified format with section headers
- Check character count limits for each section
- Ensure proper use of section markers (【】)
**Step 3: Tone and Style Review**
- Confirm objective third-person perspective is maintained
- Check for excessive adjectives or empty phrases
- Verify neutral and restrained tone throughout
**Step 4: Completeness Check**
- Ensure all four sections are present and complete
- Verify each section addresses its specific focus area
- Confirm the one-sentence summary effectively captures the user's essence
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
===Output Requirements===
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS be Chinese (Simplified)
- All section content must be in Chinese
@@ -122,3 +110,5 @@ After generating the profile, perform the following self-review steps:
- Content follows immediately after the header
- Sections are separated by blank lines
- Strictly adhere to character limits for each section
- **DO NOT include any text after the 【一句话总结】 section**
- **DO NOT output reflection steps, self-review, or verification notes**

View File

@@ -64,8 +64,8 @@ def by_mineru(filename, binary=None, from_page=0, to_page=100000, lang="Chinese"
def by_textln(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, vision_model=None, pdf_cls = None, **kwargs):
textln_api = os.environ.get("TEXTLN_APISERVER", "https://api.textin.com/ai/service/v1/pdf_to_markdown")
app_id = os.environ.get("TEXTLN_APP_ID", "fa3f24380683ad53e6c620c0f0878a09")
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "6130caac9aabc6eb26433758d7898f4a")
app_id = os.environ.get("TEXTLN_APP_ID", "")
secret_code = os.environ.get("TEXTLN_SECRET_CODE", "")
pdf_parser = TextLnParser(textln_api=textln_api, app_id=app_id, secret_code=secret_code)
sections, tables = pdf_parser.parse_pdf(

View File

@@ -448,7 +448,7 @@ if __name__ == "__main__":
# 准备配置vision_model信息
# 初始化 QWenCV
vision_model = QWenCV(
key="sk-8e9e40cd171749858ce2d3722ea75669",
key="",
model_name="qwen-vl-max",
lang="Chinese", # 默认使用中文
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"

View File

@@ -191,10 +191,14 @@ class BaseTool(ABC):
execution_time=execution_time
)
def to_langchain_tool(self):
"""转换为Langchain工具格式"""
def to_langchain_tool(self, operation: Optional[str] = None):
"""转换为Langchain工具格式
Args:
operation: 特定操作(适用于有操作的工具)
"""
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
return LangchainAdapter.convert_tool(self, operation)
def __repr__(self):
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,216 @@
"""操作工具 - 为特定操作创建的工具包装器"""
from typing import List
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.models import ToolType
class OperationTool(BaseTool):
"""操作工具 - 包装基础工具的特定操作"""
def __init__(self, base_tool: BaseTool, operation: str):
self.base_tool = base_tool
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.BUILTIN
@property
def description(self) -> str:
return f"{self.base_tool.description} - {self.operation}"
@property
def parameters(self) -> List[ToolParameter]:
"""返回特定操作的参数"""
if self.base_tool.name == 'datetime_tool':
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
else:
# 默认返回除operation外的所有参数
return [p for p in self.base_tool.parameters if p.name != "operation"]
def _get_datetime_params(self) -> List[ToolParameter]:
"""获取datetime_tool特定操作的参数"""
if self.operation == "now":
return [
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "format":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "convert_timezone":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
elif self.operation == "timestamp_to_datetime":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
def _get_json_params(self) -> List[ToolParameter]:
"""获取json_tool特定操作的参数"""
base_params = [
ToolParameter(
name="input_data",
type=ParameterType.STRING,
description="输入数据JSON字符串、YAML字符串或XML字符串",
required=True
)
]
if self.operation == "insert":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert操作",
required=True
)
]
elif self.operation == "replace":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=True
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=True
)
]
elif self.operation == "delete":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
elif self.operation == "parse":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
else:
return base_params
async def execute(self, **kwargs) -> ToolResult:
"""执行特定操作"""
# 添加operation参数
kwargs["operation"] = self.operation
return await self.base_tool.execute(**kwargs)

View File

@@ -1,4 +1,5 @@
"""自定义工具基类"""
import json
import time
from typing import Dict, Any, List, Optional
import aiohttp
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
if not self.schema_content:
return operations
if isinstance(self.schema_content, str):
try:
self.schema_content = json.loads(self.schema_content)
except json.JSONDecodeError:
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
return operations
paths = self.schema_content.get("paths", {})

View File

@@ -21,24 +21,35 @@ class LangchainToolWrapper(LangchainBaseTool):
# 内部工具实例
tool_instance: BaseTool = Field(..., description="内部工具实例")
# 特定操作(用于自定义工具)
operation: Optional[str] = Field(None, description="特定操作")
class Config:
arbitrary_types_allowed = True
def __init__(self, tool_instance: BaseTool, **kwargs):
def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs):
"""初始化Langchain工具包装器
Args:
tool_instance: 内部工具实例
operation: 特定操作(用于自定义工具)
"""
# 动态创建参数schema
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
args_schema = LangchainAdapter._create_pydantic_schema(
tool_instance.parameters, operation
)
# 构建工具名称
tool_name = tool_instance.name
if operation:
tool_name = f"{tool_instance.name}_{operation}"
super().__init__(
name=tool_instance.name,
name=tool_name,
description=tool_instance.description,
args_schema=args_schema,
_tool_instance=tool_instance,
tool_instance=tool_instance,
operation=operation,
**kwargs
)
@@ -58,8 +69,12 @@ class LangchainToolWrapper(LangchainBaseTool):
) -> str:
"""异步执行工具"""
try:
# 如果有特定操作,添加到参数中
if self.operation:
kwargs["operation"] = self.operation
# 执行内部工具
result = await self._tool_instance.safe_execute(**kwargs)
result = await self.tool_instance.safe_execute(**kwargs)
# 转换结果为Langchain格式
return LangchainAdapter._format_result_for_langchain(result)
@@ -73,24 +88,82 @@ class LangchainAdapter:
"""Langchain适配器 - 负责工具格式转换和标准化"""
@staticmethod
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
"""将内部工具转换为Langchain工具
Args:
tool: 内部工具实例
operation: 特定操作适用于有操作的工具或MCP工具名称
Returns:
Langchain兼容的工具包装器
"""
try:
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
# 处理MCP工具的特定工具名称
if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation:
# 为MCP工具创建特定工具名称的实例
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
elif operation and LangchainAdapter._tool_supports_operations(tool):
# 为支持多操作的工具创建特定操作实例
if tool.tool_type.value == "custom":
# 自定义工具直接传递operation参数
wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation)
else:
# 内置工具使用OperationTool包装
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
else:
# 单个工具
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
except Exception as e:
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
raise
@staticmethod
def _tool_supports_operations(tool: BaseTool) -> bool:
"""检查工具是否支持多操作"""
# 内置工具中支持操作的工具
builtin_operation_tools = ['datetime_tool', 'json_tool']
# 检查内置工具
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
return True
# 检查自定义工具自定义工具通过解析OpenAPI schema支持多操作
if tool.tool_type.value == "custom":
# 检查工具是否有多个操作
if hasattr(tool, '_parsed_operations') and len(tool._parsed_operations) > 1:
return True
# 或者检查参数中是否有operation参数
for param in tool.parameters:
if param.name == "operation" and param.enum:
return True
return False
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
"""为特定操作创建工具实例"""
if base_tool.tool_type.value == "builtin":
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
else:
raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}")
@staticmethod
def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool:
"""为MCP工具创建指定工具名称的实例"""
mcp_tool.set_current_tool(tool_name)
return mcp_tool
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具
@@ -110,15 +183,19 @@ class LangchainAdapter:
except Exception as e:
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
return converted_tools
@staticmethod
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
def _create_pydantic_schema(
parameters: List[ToolParameter],
operation: Optional[str] = None
) -> Type[BaseModel]:
"""根据工具参数创建Pydantic schema
Args:
parameters: 工具参数列表
operation: 特定操作(用于过滤参数)
Returns:
Pydantic模型类
@@ -127,7 +204,12 @@ class LangchainAdapter:
fields = {}
annotations = {}
for param in parameters:
# 如果指定了operation过滤掉operation参数
filtered_params = parameters
if operation:
filtered_params = [p for p in parameters if p.name != "operation"]
for param in filtered_params:
# 确定Python类型
python_type = LangchainAdapter._get_python_type(param.type)
@@ -169,9 +251,10 @@ class LangchainAdapter:
"ToolArgsSchema",
(BaseModel,),
{
"__module__": __name__,
"__annotations__": annotations,
**fields,
"Config": type("Config", (), {"extra": "forbid"})
"model_config": {"extra": "forbid"},
**fields
}
)

View File

@@ -1,12 +1,20 @@
"""MCP工具模块"""
"""MCP 工具模块 - Model Context Protocol 支持"""
from app.core.tools.mcp.base import MCPTool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.service_manager import MCPServiceManager
# 主要类导出
from .base import MCPTool, MCPToolManager, MCPError
from .client import SimpleMCPClient, MCPConnectionError
from .service_manager import MCPServiceManager
__all__ = [
# 核心类
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPToolManager",
"MCPError",
# 客户端类
"SimpleMCPClient",
"MCPConnectionError",
# 服务管理(简化版)
"MCPServiceManager"
]

View File

@@ -1,10 +1,9 @@
"""MCP工具基类"""
"""MCP工具基类 - 整合版本"""
import time
from typing import Dict, Any, List
from typing import List, Dict, Any
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -14,215 +13,188 @@ class MCPTool(BaseTool):
"""MCP工具 - Model Context Protocol工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
"""初始化MCP工具
Args:
tool_id: 工具ID
config: 工具配置
"""
super().__init__(tool_id, config)
self.server_url = config.get("server_url", "")
self.connection_config = config.get("connection_config", {})
self.available_tools = config.get("available_tools", [])
self._client = None
self._connected = False
@property
def name(self) -> str:
"""工具名称"""
return f"mcp_tool_{self.tool_id[:8]}"
@property
def description(self) -> str:
"""工具描述"""
return f"MCP工具 - 连接到 {self.server_url}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.MCP
@property
def parameters(self) -> List[ToolParameter]:
"""工具参数定义"""
params = []
"""根据工具名称返回对应参数"""
# 如果有指定的工具名称,从 available_tools 中获取参数
tool_name = getattr(self, '_current_tool_name', None)
if tool_name and self.available_tools:
for tool_info in self.available_tools:
if tool_info.get("tool_name") == tool_name:
arguments = tool_info.get("arguments", {})
return self._generate_parameters_from_schema(arguments)
# 添加工具选择参数
if len(self.available_tools) > 1:
params.append(ToolParameter(
# 默认返回通用参数
return [
ToolParameter(
name="tool_name",
type=ParameterType.STRING,
description="调用的MCP工具名称",
required=True,
enum=self.available_tools
))
# 添加通用参数
params.extend([
description="执行的工具名称",
required=True
),
ToolParameter(
name="arguments",
type=ParameterType.OBJECT,
description="工具参数JSON对象",
description="工具参数",
required=False,
default={}
),
ToolParameter(
name="timeout",
type=ParameterType.INTEGER,
description="超时时间(秒)",
required=False,
default=30,
minimum=1,
maximum=300
)
])
]
def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]:
"""从参数schema生成参数列表"""
properties = arguments.get("properties", {})
required_fields = arguments.get("required", [])
params = []
for param_name, param_def in properties.items():
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
params.append(ToolParameter(
name=param_name,
type=param_type,
description=param_def.get("description", f"参数: {param_name}"),
required=param_name in required_fields,
default=param_def.get("default"),
enum=param_def.get("enum"),
minimum=param_def.get("minimum"),
maximum=param_def.get("maximum")
))
return params
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
"""转换JSON Schema类型到ParameterType"""
type_mapping = {
"string": ParameterType.STRING,
"integer": ParameterType.INTEGER,
"number": ParameterType.NUMBER,
"boolean": ParameterType.BOOLEAN,
"array": ParameterType.ARRAY,
"object": ParameterType.OBJECT
}
return type_mapping.get(json_type, ParameterType.STRING)
def set_current_tool(self, tool_name: str):
"""设置当前工具名称,用于获取特定参数"""
self._current_tool_name = tool_name
async def execute(self, **kwargs) -> ToolResult:
"""执行MCP工具"""
start_time = time.time()
try:
# 确保连接
if not self._connected:
await self.connect()
# 确定要调用的工具
tool_name = kwargs.get("tool_name")
if not tool_name and len(self.available_tools) == 1:
tool_name = self.available_tools[0]
if not tool_name:
raise ValueError("必须指定要调用的MCP工具名称")
raise Exception("未指定工具名称")
if tool_name not in self.available_tools:
raise ValueError(f"MCP工具不存在: {tool_name}")
# 获取参数
arguments = kwargs.get("arguments", {})
timeout = kwargs.get("timeout", 30)
# 调用MCP工具
result = await self._call_mcp_tool(tool_name, arguments, timeout)
from .client import SimpleMCPClient
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
client = SimpleMCPClient(self.server_url, self.connection_config)
async with client:
result = await client.call_tool(tool_name, arguments)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
logger.error(f"MCP工具执行失败: {kwargs.get('tool_name', 'unknown')}, 错误: {e}")
return ToolResult.error_result(
error=str(e),
error_code="MCP_ERROR",
error_code="MCP_EXECUTION_ERROR",
execution_time=execution_time
)
class MCPError(Exception):
"""MCP 错误基类"""
pass
class MCPToolManager:
"""MCP 工具管理器 - 简化版本"""
async def connect(self) -> bool:
"""连接到MCP服务器"""
def __init__(self, db=None):
self.db = db
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
async def discover_tools(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> tuple[bool, List[Dict[str, Any]], str | None]:
"""发现 MCP 服务器上的工具"""
try:
from .client import MCPClient
from .client import SimpleMCPClient
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# 更新可用工具列表
await self._update_available_tools()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
logger.error(f"MCP服务器连接失败: {self.server_url}")
return False
client = SimpleMCPClient(server_url, connection_config)
async with client:
tools = await client.list_tools()
# 缓存工具信息
self._tool_cache[server_url] = {
"tools": tools,
"connection_config": connection_config,
"last_updated": time.time()
}
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
return True, tools, None
except Exception as e:
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
self._connected = False
return False
error_msg = f"发现工具失败: {e}"
logger.error(error_msg)
return False, [], error_msg
async def _update_available_tools(self):
"""更新可用工具列表"""
async def test_tool_connection(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> Dict[str, Any]:
"""测试工具连接"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
await self._client.disconnect()
self._client = None
from .client import SimpleMCPClient
self._connected = False
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def get_health_status(self) -> Dict[str, Any]:
"""获取MCP服务健康状态"""
return {
"connected": self._connected,
"server_url": self.server_url,
"available_tools": self.available_tools,
"last_check": time.time()
}
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
try:
if not self._connected:
await self.connect()
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []
except Exception as e:
logger.error(f"获取MCP工具列表失败: {e}")
return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
client = SimpleMCPClient(server_url, connection_config)
async with client:
tools = await client.list_tools()
return {
"success": True,
"tools_count": len(tools),
"tools": [tool.get("name") for tool in tools],
"message": "连接成功"
}
except Exception as e:
return {
"success": False,
"error": str(e)
"error": str(e),
"message": "连接失败"
}

View File

@@ -1,9 +1,8 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
"""MCP客户端 - 简化版本"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
from typing import Dict, Any, List
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
@@ -18,139 +17,156 @@ class MCPConnectionError(Exception):
pass
class MCPProtocolError(Exception):
"""MCP协议错误"""
pass
class MCPClient:
"""MCP客户端 - 支持HTTP和WebSocket连接"""
class SimpleMCPClient:
"""简化的 MCP 客户端"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""初始化MCP客户端
Args:
server_url: MCP服务器URL
connection_config: 连接配置
"""
self.server_url = server_url
self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30)
# 解析URL确定连接类型
parsed_url = urlparse(server_url)
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://"))
# 连接状态
self._connected = False
self._websocket = None
self._session = None
# 请求管理
self._request_id = 0
self._pending_requests: Dict[str, asyncio.Future] = {}
# 连接池配置
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# 健康检查
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# 事件回调
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
self._pending_requests = {}
async def connect(self) -> bool:
"""连接到MCP服务器
Returns:
连接是否成功
"""
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
async def connect(self):
"""建立连接"""
try:
if self._connected:
return True
logger.info(f"连接MCP服务器: {self.server_url}")
if self.connection_type == "websocket":
success = await self._connect_websocket()
if self.is_websocket:
await self._connect_websocket()
else:
success = await self._connect_http()
if success:
self._connected = True
await self._start_health_check()
await self._notify_connect_callbacks()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return success
await self._connect_http()
except Exception as e:
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
await self._notify_error_callbacks(e)
return False
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
raise MCPConnectionError(f"连接失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接
Returns:
断开是否成功
"""
async def disconnect(self):
"""断开连接"""
try:
if not self._connected:
return True
logger.info(f"断开MCP服务器连接: {self.server_url}")
# 停止健康检查
await self._stop_health_check()
# 取消所有待处理的请求
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
# 断开连接
if self.connection_type == "websocket" and self._websocket:
if self._websocket:
await self._websocket.close()
self._websocket = None
elif self._session:
if self._session:
await self._session.close()
self._session = None
self._connected = False
await self._notify_disconnect_callbacks()
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
logger.error(f"断开连接失败: {e}")
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
async def _connect_websocket(self):
"""WebSocket 连接"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize()
async def _connect_http(self):
"""HTTP 连接"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(
headers=headers,
timeout=timeout
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
}
}
try:
async with self._session.post(
self.server_url,
json=init_request
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
init_response = await response.json()
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(
self.server_url,
json=initialized_notification
) as notif_response:
pass
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 添加认证头
auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none")
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
if auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "api_key":
key = auth_config.get("api_key")
header_name = auth_config.get("key_name", "X-API-Key")
if key:
headers[header_name] = key
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
@@ -161,160 +177,63 @@ class MCPClient:
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
extra_headers=extra_headers,
timeout=self.connection_timeout
)
# 启动消息监听
asyncio.create_task(self._websocket_message_handler())
# 发送初始化消息
init_message = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "ToolManagementSystem",
"version": "1.0.0"
}
async def _send_initialize(self):
"""发送初始化消息"""
init_message = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.connection_timeout
)
init_response = json.loads(response)
if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
except Exception as e:
logger.error(f"WebSocket连接失败: {e}")
return False
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.timeout
)
init_response = json.loads(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
async def _connect_http(self) -> bool:
"""建立HTTP连接"""
try:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=headers
)
# 测试连接
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
async with self._session.get(test_url) as response:
if response.status == 200:
return True
else:
# 尝试根路径
async with self._session.get(self.server_url) as root_response:
return root_response.status < 400
except Exception as e:
logger.error(f"HTTP连接失败: {e}")
if self._session:
await self._session.close()
self._session = None
return False
async def _websocket_message_handler(self):
"""WebSocket消息处理器"""
async def _handle_websocket_messages(self):
"""处理 WebSocket 消息"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
await self._handle_message(json.loads(message))
data = json.loads(message)
# 处理响应
if "id" in data:
request_id = str(data["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except json.JSONDecodeError as e:
logger.error(f"解析WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}")
finally:
self._connected = False
await self._notify_disconnect_callbacks()
logger.error(f"WebSocket消息处理异常: {e}")
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
# 这里可以根据需要处理特定的通知
# 例如:工具列表更新、服务器状态变化等
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
"""调用MCP工具
Args:
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间(秒)
Returns:
工具执行结果
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
@@ -322,343 +241,69 @@ class MCPClient:
}
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
except asyncio.TimeoutError:
raise MCPProtocolError(f"工具调用超时: {tool_name}")
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
"""获取可用工具列表
Args:
timeout: 超时时间(秒)
Returns:
工具列表
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/list"
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
except asyncio.TimeoutError:
raise MCPProtocolError("获取工具列表超时")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
if self.connection_type == "websocket":
request_id = str(request_data["id"])
return await self._send_websocket_request(request_data, request_id, timeout)
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
response = await self._send_http_request(request_data)
# 创建Future等待响应
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
await self._pending_requests.pop(request_id, None)
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
if not self._session:
raise MCPConnectionError("HTTP会话未建立")
try:
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
async with self._session.post(
url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
self.server_url,
json=request_data
) as response:
if response.status == 200:
return await response.json()
else:
async with self._session.post(
self.server_url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as root_response:
if root_response.status != 200:
error_text = await root_response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
"""执行健康检查
Returns:
健康状态信息
"""
try:
if not self._connected:
return {
"healthy": False,
"error": "未连接",
"timestamp": time.time()
}
# 发送ping请求
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "ping"
}
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,
"response_time": response_time,
"timestamp": self._last_health_check,
"server_info": response.get("result", {})
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
"""注册连接回调"""
self._on_connect_callbacks.append(callback)
def on_disconnect(self, callback: Callable):
"""注册断开连接回调"""
self._on_disconnect_callbacks.append(callback)
def on_error(self, callback: Callable):
"""注册错误回调"""
self._on_error_callbacks.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
@property
def last_health_check(self) -> Optional[float]:
"""获取最后一次健康检查时间"""
return self._last_health_check
def get_connection_info(self) -> Dict[str, Any]:
"""获取连接信息"""
return {
"server_url": self.server_url,
"connection_type": self.connection_type,
"connected": self._connected,
"last_health_check": self._last_health_check,
"pending_requests": len(self._pending_requests),
"config": self.connection_config
}
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
class MCPConnectionPool:
"""MCP连接池 - 管理多个MCP客户端连接"""
def __init__(self, max_connections: int = 10):
"""初始化连接池
Args:
max_connections: 最大连接数
"""
self.max_connections = max_connections
self._clients: Dict[str, MCPClient] = {}
self._lock = asyncio.Lock()
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
"""获取或创建MCP客户端
Args:
server_url: 服务器URL
connection_config: 连接配置
Returns:
MCP客户端实例
"""
async with self._lock:
if server_url in self._clients:
client = self._clients[server_url]
if client.is_connected:
return client
else:
# 尝试重连
if await client.connect():
return client
else:
# 移除失效的客户端
del self._clients[server_url]
# 检查连接数限制
if len(self._clients) >= self.max_connections:
# 移除最旧的连接
oldest_url = next(iter(self._clients))
await self._clients[oldest_url].disconnect()
del self._clients[oldest_url]
# 创建新客户端
client = MCPClient(server_url, connection_config)
if await client.connect():
self._clients[server_url] = client
return client
else:
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
async def disconnect_all(self):
"""断开所有连接"""
async with self._lock:
for client in self._clients.values():
await client.disconnect()
self._clients.clear()
def get_pool_status(self) -> Dict[str, Any]:
"""获取连接池状态"""
return {
"total_connections": len(self._clients),
"max_connections": self.max_connections,
"connections": {
url: client.get_connection_info()
for url, client in self._clients.items()
}
}
return f"req_{self._request_id}_{int(time.time() * 1000)}"

View File

@@ -1,6 +1,4 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
"""MCP服务管理器 - 简化版本"""
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
@@ -8,133 +6,53 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.base import MCPToolManager
logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
"""MCP服务管理器 - 简化版本,主要用于工具创建"""
def __init__(self, db: Session):
"""初始化MCP服务管理器
Args:
db: 数据库会话
"""
def __init__(self, db: Session = None):
self.db = db
self.connection_pool = MCPConnectionPool(max_connections=20)
# 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
# 配置
self.health_check_interval = 60 # 健康检查间隔(秒)
self.max_retry_attempts = 3 # 最大重试次数
self.retry_delay = 5 # 重试延迟(秒)
# 状态
self._running = False
self._manager_task = None
self.tool_manager = MCPToolManager(db) if db else None
async def start(self):
"""启动服务管理器"""
if self._running:
return
self._running = True
logger.info("MCP服务管理器启动")
# 加载现有服务
await self._load_existing_services()
# 启动管理任务
self._manager_task = asyncio.create_task(self._management_loop())
async def stop(self):
"""停止服务管理器"""
if not self._running:
return
self._running = False
logger.info("MCP服务管理器停止")
# 停止管理任务
if self._manager_task:
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
# 停止所有监控任务
for task in self._monitoring_tasks.values():
task.cancel()
if self._monitoring_tasks:
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
self._monitoring_tasks.clear()
# 断开所有连接
await self.connection_pool.disconnect_all()
async def register_service(
async def create_mcp_tool(
self,
server_url: str,
connection_config: Dict[str, Any],
tenant_id: uuid.UUID,
tool_name: str,
service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
"""注册MCP服务
"""创建单个MCP工具
Args:
server_url: 服务器URL
connection_config: 连接配置
tenant_id: 租户ID
service_name: 服务名称(可选)
tool_name: 具体工具名称
service_name: 服务名称
Returns:
(是否成功, 服务ID或错误信息, 错误详情)
(是否成功, 工具ID或错误信息, 错误详情)
"""
try:
# 检查服务是否已存在
existing_service = self.db.query(MCPToolConfig).filter(
MCPToolConfig.server_url == server_url
).first()
if existing_service:
return False, "服务已存在", f"URL {server_url} 已被注册"
# 测试连接
try:
client = MCPClient(server_url, connection_config)
if not await client.connect():
return False, "连接测试失败", "无法连接到MCP服务器"
# 获取可用工具
available_tools = await client.list_tools()
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
await client.disconnect()
except Exception as e:
return False, "连接测试失败", str(e)
if not service_name:
service_name = f"mcp_{tool_name}"
# 创建工具配置
if not service_name:
service_name = f"mcp_service_{server_url.split('/')[-1]}"
tool_config = ToolConfig(
name=service_name,
description=f"MCP服务 - {server_url}",
description=f"MCP工具: {tool_name}",
tool_type=ToolType.MCP.value,
tenant_id=tenant_id,
version="1.0.0",
status=ToolStatus.AVAILABLE.value,
config_data={
"server_url": server_url,
"connection_config": connection_config
"connection_config": connection_config,
"tool_name": tool_name
}
)
@@ -146,460 +64,22 @@ class MCPServiceManager:
id=tool_config.id,
server_url=server_url,
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
available_tools=[tool_name],
health_status="unknown",
last_health_check=datetime.now()
)
self.db.add(mcp_config)
self.db.commit()
service_id = str(tool_config.id)
# 添加到内存管理
self._services[service_id] = {
"id": service_id,
"server_url": server_url,
"connection_config": connection_config,
"tenant_id": tenant_id,
"available_tools": tool_names,
"status": "healthy",
"last_health_check": time.time(),
"retry_count": 0,
"created_at": time.time()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
return True, service_id, None
logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
return True, str(tool_config.id), None
except Exception as e:
self.db.rollback()
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
return False, "注册失败", str(e)
logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
return False, "创建失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
"""注销MCP服务
Args:
service_id: 服务ID
Returns:
(是否成功, 错误信息)
"""
try:
# 从数据库删除
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if not tool_config:
return False, "服务不存在"
self.db.delete(tool_config)
self.db.commit()
# 停止监控
await self._stop_service_monitoring(service_id)
# 从内存移除
if service_id in self._services:
del self._services[service_id]
logger.info(f"MCP服务注销成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def update_service(
self,
service_id: str,
connection_config: Dict[str, Any] = None,
enabled: bool = None
) -> Tuple[bool, str]:
"""更新MCP服务配置
Args:
service_id: 服务ID
connection_config: 新的连接配置
enabled: 是否启用
Returns:
(是否成功, 错误信息)
"""
try:
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if not mcp_config:
return False, "服务不存在"
tool_config = mcp_config.base_config
if connection_config is not None:
mcp_config.connection_config = connection_config
tool_config.config_data["connection_config"] = connection_config
if enabled is not None:
tool_config.is_enabled = enabled
self.db.commit()
# 更新内存状态
if service_id in self._services:
if connection_config is not None:
self._services[service_id]["connection_config"] = connection_config
# 如果配置有变化,重启监控
if connection_config is not None:
await self._restart_service_monitoring(service_id)
logger.info(f"MCP服务更新成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
"""获取服务状态
Args:
service_id: 服务ID
Returns:
服务状态信息
"""
if service_id not in self._services:
return None
service_info = self._services[service_id].copy()
# 添加实时健康检查
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
service_info["real_time_health"] = health_status
except Exception as e:
service_info["real_time_health"] = {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
return service_info
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
"""列出所有服务
Args:
tenant_id: 租户ID过滤
Returns:
服务列表
"""
services = []
for service_id, service_info in self._services.items():
if tenant_id and service_info["tenant_id"] != tenant_id:
continue
services.append(service_info.copy())
return services
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的可用工具
Args:
service_id: 服务ID
Returns:
工具列表
"""
if service_id not in self._services:
return []
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
tools = await client.list_tools()
# 更新缓存的工具列表
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.available_tools = tool_names
self.db.commit()
return tools
except Exception as e:
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
return []
async def call_service_tool(
self,
service_id: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: int = 30
) -> Dict[str, Any]:
"""调用服务工具
Args:
service_id: 服务ID
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
执行结果
"""
if service_id not in self._services:
raise ValueError(f"服务不存在: {service_id}")
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
result = await client.call_tool(tool_name, arguments, timeout)
# 更新服务状态为健康
service_info["status"] = "healthy"
service_info["last_health_check"] = time.time()
service_info["retry_count"] = 0
return result
except Exception as e:
# 更新服务状态为错误
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
raise
async def _load_existing_services(self):
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
tool_config = mcp_config.base_config
service_id = str(mcp_config.id)
self._services[service_id] = {
"id": service_id,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"tenant_id": tool_config.tenant_id,
"available_tools": mcp_config.available_tools or [],
"status": mcp_config.health_status or "unknown",
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
"retry_count": 0,
"created_at": tool_config.created_at.timestamp()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
except Exception as e:
logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
"""监控单个服务"""
try:
while self._running and service_id in self._services:
service_info = self._services[service_id]
try:
# 执行健康检查
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
if health_status["healthy"]:
# 服务健康
service_info["status"] = "healthy"
service_info["retry_count"] = 0
# 更新工具列表
try:
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
except Exception as e:
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
else:
# 服务不健康
service_info["status"] = "unhealthy"
service_info["last_error"] = health_status.get("error", "健康检查失败")
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
# 更新数据库
await self._update_service_health_in_db(service_id, health_status)
except Exception as e:
# 监控异常
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
# 如果重试次数过多,暂停监控
if service_info["retry_count"] >= self.max_retry_attempts:
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
service_info["retry_count"] = 0 # 重置重试计数
# 等待下次检查
await asyncio.sleep(self.health_check_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
"""更新数据库中的服务健康状态"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")
else:
mcp_config.error_message = None
self.db.commit()
except Exception as e:
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {
"running": self._running,
"total_services": len(self._services),
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
"monitoring_tasks": len(self._monitoring_tasks),
"connection_pool_status": self.connection_pool.get_pool_status()
}
def get_tool_manager(self) -> MCPToolManager:
"""获取工具管理器实例"""
return self.tool_manager

View File

@@ -64,6 +64,11 @@ def validate_model_exists_and_active(
) -> tuple[str, bool]:
"""Validate that a model exists and is active.
This function performs tenant-aware model validation with detailed error messages:
- If model doesn't exist at all: "Model not found"
- If model exists but belongs to different tenant: "Model belongs to different tenant" with details
- If model exists and accessible but inactive: "Model is inactive"
Args:
model_id: Model UUID to validate
model_type: Type of model ("llm", "embedding", "rerank")
@@ -76,7 +81,7 @@ def validate_model_exists_and_active(
Tuple of (model_name, is_active)
Raises:
ModelNotFoundError: If model does not exist
ModelNotFoundError: If model does not exist or belongs to different tenant
ModelInactiveError: If model exists but is inactive
"""
from app.repositories.model_repository import ModelConfigRepository
@@ -84,21 +89,48 @@ def validate_model_exists_and_active(
start_time = time.time()
try:
# First check if model exists at all (without tenant filtering)
model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None)
# Then check with tenant filtering
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id)
elapsed_ms = (time.time() - start_time) * 1000
if not model:
logger.warning(
"Model not found",
extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
)
raise ModelNotFoundError(
model_id=model_id,
model_type=model_type,
config_id=config_id,
workspace_id=workspace_id,
message=f"{model_type.title()} model {model_id} not found"
)
if model_without_tenant:
# Model exists but belongs to different tenant
logger.warning(
"Model belongs to different tenant",
extra={
"model_id": str(model_id),
"model_type": model_type,
"model_name": model_without_tenant.name,
"model_tenant_id": str(model_without_tenant.tenant_id),
"requested_tenant_id": str(tenant_id),
"is_public": model_without_tenant.is_public,
"elapsed_ms": elapsed_ms
}
)
raise ModelNotFoundError(
model_id=model_id,
model_type=model_type,
config_id=config_id,
workspace_id=workspace_id,
message=f"{model_type.title()} model {model_id} ({model_without_tenant.name}) belongs to a different tenant (model tenant: {model_without_tenant.tenant_id}, workspace tenant: {tenant_id}). The model is not public and cannot be accessed from this workspace."
)
else:
# Model doesn't exist at all
logger.warning(
"Model not found",
extra={"model_id": str(model_id), "model_type": model_type, "elapsed_ms": elapsed_ms}
)
raise ModelNotFoundError(
model_id=model_id,
model_type=model_type,
config_id=config_id,
workspace_id=workspace_id,
message=f"{model_type.title()} model {model_id} not found"
)
if not model.is_active:
logger.warning(

View File

@@ -22,7 +22,7 @@ class AssignmentItem(BaseModel):
)
value: Any = Field(
...,
default=None,
description="Value(s) to assign to the variable(s)",
)

View File

@@ -29,7 +29,7 @@ class WorkflowState(TypedDict):
# Set of loop node IDs, used for assigning values in loop nodes
cycle_nodes: list
looping: bool
looping: Annotated[bool, lambda x, y: x and y]
# Input variables (passed from configured variables)
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
@@ -534,7 +534,7 @@ class BaseNode(ABC):
return edge
return None
def _render_template(self, template: str, state: WorkflowState | None) -> str:
def _render_template(self, template: str, state: WorkflowState | None, struct: bool = True) -> str:
"""渲染模板
支持的变量命名空间:
@@ -568,7 +568,8 @@ class BaseNode(ABC):
template=template,
variables=variables,
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars()
system_vars=pool.get_all_system_vars(),
struct=struct
)
def _evaluate_condition(self, expression: str, state: WorkflowState | None) -> bool:

View File

@@ -1,6 +1,6 @@
from typing import Any
from pydantic import Field, BaseModel
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
@@ -27,6 +27,16 @@ class CycleVariable(BaseNodeConfig):
description="Initial or current value of the loop variable"
)
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
if isinstance(v, str):
try:
return ValueInputType(v.lower())
except ValueError:
raise ValueError(f"Invalid input_type: {v}")
return v
class ConditionDetail(BaseModel):
operator: ComparisonOperator = Field(
@@ -45,10 +55,20 @@ class ConditionDetail(BaseModel):
)
input_type: ValueInputType = Field(
...,
default=ValueInputType.CONSTANT,
description="Input type of the loop variable"
)
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
if isinstance(v, str):
try:
return ValueInputType(v.lower())
except ValueError:
raise ValueError(f"Invalid input_type: {v}")
return v
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""

View File

@@ -37,7 +37,7 @@ class EndNode(BaseNode):
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
if output_template:
output = self._render_template(output_template, state)
output = self._render_template(output_template, state, struct=False)
else:
output = "工作流已完成"

View File

@@ -93,5 +93,5 @@ class HttpErrorHandle(StrEnum):
class ValueInputType(StrEnum):
VARIABLE = "Variable"
CONSTANT = "Constant"
VARIABLE = "variable"
CONSTANT = "constant"

View File

@@ -73,8 +73,10 @@ class HttpContentTypeConfig(BaseModel):
content_type = info.data.get("content_type")
if content_type == HttpContentType.FROM_DATA and not isinstance(v, HttpFormData):
raise ValueError("When content_type is 'form-data', data must be of type HttpFormData")
elif content_type in [HttpContentType.JSON, HttpContentType.WWW_FORM] and not isinstance(v, dict):
raise ValueError("When content_type is JSON or x-www-form-urlencoded, data must be a object")
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
raise ValueError("When content_type is JSON, data must be of type str")
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
raise ValueError("When content_type is x-www-form-urlencoded, data must be an object(dict)")
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
return v

View File

@@ -120,7 +120,7 @@ class HttpRequestNode(BaseNode):
return {}
case HttpContentType.JSON:
content["json"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), state
self.typed_config.body.data, state
))
case HttpContentType.FROM_DATA:
data = {}
@@ -208,17 +208,12 @@ class HttpRequestNode(BaseNode):
retries -= 1
if retries > 0:
await asyncio.sleep(self.typed_config.retry.retry_interval / 1000)
elif self.typed_config.error_handle.method == HttpErrorHandle.NONE:
raise e
except Exception as e:
raise RuntimeError(f"HTTP request node exception: {e}")
else:
match self.typed_config.error_handle.method:
case HttpErrorHandle.NONE:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning error response"
)
return HttpRequestNodeOutput(
body="",
status_code=resp.status_code,
headers=resp.headers,
).model_dump()
case HttpErrorHandle.DEFAULT:
logger.warning(
f"Node {self.node_id}: HTTP request failed, returning default result"
@@ -229,3 +224,4 @@ class HttpRequestNode(BaseNode):
f"Node {self.node_id}: HTTP request failed, switching to error handling branch"
)
return "ERROR"
raise RuntimeError("http request failed")

View File

@@ -23,10 +23,20 @@ class ConditionDetail(BaseModel):
)
input_type: ValueInputType = Field(
...,
default=ValueInputType.CONSTANT,
description="Value input type for comparison"
)
@field_validator("input_type", mode="before")
@classmethod
def lower_input_type(cls, v):
if isinstance(v, str):
try:
return ValueInputType(v.lower())
except ValueError:
raise ValueError(f"Invalid input_type: {v}")
return v
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""

View File

@@ -71,7 +71,10 @@ class IfElseNode(BaseNode):
for expression in case_branch.expressions:
pattern = r"\{\{\s*(.*?)\s*\}\}"
left_string = re.sub(pattern, r"\1", expression.left).strip()
left_value = self.get_variable(left_string, state)
try:
left_value = self.get_variable(left_string, state)
except KeyError:
left_value = None
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
self.get_variable_pool(state),
expression.left,

View File

@@ -203,15 +203,20 @@ class KnowledgeRetrievalNode(BaseNode):
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
# Deduplicate hy brid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
continue
vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
case _:
raise RuntimeError("Unknown retrieval type")
if not rs:
return []
vector_service.reranker = self.get_reranker_model()
# TODO其他重排序方式支持
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
logger.info(
f"Node {self.node_id}: knowledge base retrieval completed, results count: {len(final_rs)}"
)
return [chunk.model_dump() for chunk in final_rs]
return [chunk.page_content for chunk in final_rs]

View File

@@ -1,5 +1,7 @@
"""LLM 节点配置"""
from typing import Any
from pydantic import BaseModel, Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition, VariableType
@@ -7,17 +9,17 @@ from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefiniti
class MessageConfig(BaseModel):
"""消息配置"""
role: str = Field(
...,
description="消息角色system, user, assistant"
)
content: str = Field(
...,
description="消息内容,支持模板变量,如:{{ sys.message }}"
)
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
@@ -35,24 +37,29 @@ class LLMNodeConfig(BaseNodeConfig):
1. 简单模式:使用 prompt 字段
2. 消息模式:使用 messages 字段(推荐)
"""
model_id: str = Field(
...,
description="模型配置 ID"
)
context: Any = Field(
default="",
description="上下文"
)
# 简单模式
prompt: str | None = Field(
default=None,
description="提示词模板(简单模式),支持变量引用"
)
# 消息模式(推荐)
messages: list[MessageConfig] | None = Field(
default=None,
description="消息列表(消息模式),支持多轮对话"
)
# 模型参数
temperature: float | None = Field(
default=0.7,
@@ -60,35 +67,35 @@ class LLMNodeConfig(BaseNodeConfig):
le=2.0,
description="温度参数,控制输出的随机性"
)
max_tokens: int | None = Field(
default=1000,
ge=1,
le=32000,
description="最大生成 token 数"
)
top_p: float | None = Field(
default=None,
ge=0.0,
le=1.0,
description="Top-p 采样参数"
)
frequency_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="频率惩罚"
)
presence_penalty: float | None = Field(
default=None,
ge=-2.0,
le=2.0,
description="存在惩罚"
)
# 输出变量定义
output_variables: list[VariableDefinition] = Field(
default_factory=lambda: [
@@ -105,14 +112,14 @@ class LLMNodeConfig(BaseNodeConfig):
],
description="输出变量定义(自动生成,通常不需要修改)"
)
@field_validator("messages", "prompt")
@classmethod
def validate_input_mode(cls, v, info):
"""验证输入模式prompt 和 messages 至少有一个"""
# 这个验证在 model_validator 中更合适
return v
class Config:
json_schema_extra = {
"examples": [

View File

@@ -5,15 +5,17 @@ LLM 节点实现
"""
import logging
import re
from typing import Any
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
@@ -63,8 +65,16 @@ class LLMNode(BaseNode):
- user/human: 用户消息HumanMessage
- ai/assistant: AI 消息AIMessage
"""
def _prepare_llm(self, state: WorkflowState,stream:bool = False) -> tuple[RedBearLLM, list | str]:
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = LLMNodeConfig(**self.config)
def _render_context(self, message, state):
context = f"<context>{self._render_template(self.typed_config.context, state)}</context>"
return re.sub(r"{{context}}", context, message)
def _prepare_llm(self, state: WorkflowState, stream: bool = False) -> tuple[RedBearLLM, list | str]:
"""准备 LLM 实例(公共逻辑)
Args:
@@ -76,15 +86,16 @@ class LLMNode(BaseNode):
# 1. 处理消息格式(优先使用 messages
messages_config = self.config.get("messages")
if messages_config:
# 使用 LangChain 消息格式
messages = []
for msg_config in messages_config:
role = msg_config.get("role", "user").lower()
content_template = msg_config.get("content", "")
content_template = self._render_context(content_template, state)
content = self._render_template(content_template, state)
# 根据角色创建对应的消息对象
if role == "system":
messages.append(SystemMessage(content=content))
@@ -95,7 +106,7 @@ class LLMNode(BaseNode):
else:
logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append(HumanMessage(content=content))
prompt_or_messages = messages
else:
# 使用简单的 prompt 格式(向后兼容)
@@ -106,17 +117,17 @@ class LLMNode(BaseNode):
model_id = self.config.get("model_id")
if not model_id:
raise ValueError(f"节点 {self.node_id} 缺少 model_id 配置")
# 3. 在 with 块内完成所有数据库操作和数据提取
with get_db_context() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
if not config:
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
# 在 Session 关闭前提取所有需要的数据
api_config = config.api_keys[0]
model_name = api_config.model_name
@@ -124,26 +135,26 @@ class LLMNode(BaseNode):
api_key = api_config.api_key
api_base = api_config.api_base
model_type = config.type
# 4. 创建 LLM 实例(使用已提取的数据)
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
extra_params = {"streaming": stream} if stream else {}
llm = RedBearLLM(
RedBearModelConfig(
model_name=model_name,
provider=provider,
provider=provider,
api_key=api_key,
base_url=api_base,
extra_params=extra_params
),
),
type=ModelType(model_type)
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
return llm, prompt_or_messages
async def execute(self, state: WorkflowState) -> AIMessage:
"""非流式执行 LLM 调用
@@ -153,10 +164,10 @@ class LLMNode(BaseNode):
Returns:
LLM 响应消息
"""
llm, prompt_or_messages = self._prepare_llm(state,True)
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
# 调用 LLM支持字符串或消息列表
response = await llm.ainvoke(prompt_or_messages)
# 提取内容
@@ -164,16 +175,16 @@ class LLMNode(BaseNode):
content = response.content
else:
content = str(response)
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
# 返回 AIMessage包含响应元数据
return response if isinstance(response, AIMessage) else AIMessage(content=content)
def _extract_input(self, state: WorkflowState) -> dict[str, Any]:
"""提取输入数据(用于记录)"""
_, prompt_or_messages = self._prepare_llm(state)
return {
"prompt": prompt_or_messages if isinstance(prompt_or_messages, str) else None,
"messages": [
@@ -186,13 +197,13 @@ class LLMNode(BaseNode):
"max_tokens": self.config.get("max_tokens")
}
}
def _extract_output(self, business_result: Any) -> str:
"""从 AIMessage 中提取文本内容"""
if isinstance(business_result, AIMessage):
return business_result.content
return str(business_result)
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
"""从 AIMessage 中提取 token 使用情况"""
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
@@ -204,7 +215,7 @@ class LLMNode(BaseNode):
"total_tokens": usage.get('total_tokens', 0)
}
return None
async def execute_stream(self, state: WorkflowState):
"""流式执行 LLM 调用
@@ -215,26 +226,26 @@ class LLMNode(BaseNode):
文本片段chunk或完成标记
"""
from langgraph.config import get_stream_writer
llm, prompt_or_messages = self._prepare_llm(state, True)
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
# 检查是否有注入的 End 节点前缀配置
writer = get_stream_writer()
end_prefix = getattr(self, '_end_node_prefix', None)
logger.info(f"[LLM前缀] 节点 {self.node_id} 检查前缀配置: {end_prefix is not None}")
if end_prefix:
logger.info(f"[LLM前缀] 前缀内容: '{end_prefix}'")
if end_prefix:
# 渲染前缀(可能包含其他变量)
try:
rendered_prefix = self._render_template(end_prefix, state)
logger.info(f"节点 {self.node_id} 提前发送 End 节点前缀: '{rendered_prefix[:50]}...'")
# 提前发送 End 节点的前缀(使用 "message" 类型)
writer({
"type": "message", # End 相关的内容都是 message 类型
@@ -246,12 +257,12 @@ class LLMNode(BaseNode):
})
except Exception as e:
logger.warning(f"渲染/发送 End 节点前缀失败: {e}")
# 累积完整响应
full_response = ""
last_chunk = None
chunk_count = 0
# 调用 LLM流式支持字符串或消息列表
async for chunk in llm.astream(prompt_or_messages):
# 提取内容
@@ -259,18 +270,18 @@ class LLMNode(BaseNode):
content = chunk.content
else:
content = str(chunk)
# 只有当内容不为空时才处理
if content:
full_response += content
last_chunk = chunk
chunk_count += 1
# 流式返回每个文本片段
yield content
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
# 构建完整的 AIMessage包含元数据
if isinstance(last_chunk, AIMessage):
final_message = AIMessage(
@@ -279,6 +290,6 @@ class LLMNode(BaseNode):
)
else:
final_message = AIMessage(content=full_response)
# yield 完成标记
yield {"__final__": True, "result": final_message}

View File

@@ -24,7 +24,7 @@ class MemoryReadNode(BaseNode):
return await MemoryAgentService().read_memory(
group_id=end_user_id,
message=self.typed_config.message,
message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id,
search_switch=self.typed_config.search_switch,
history=[],
@@ -51,7 +51,7 @@ class MemoryWriteNode(BaseNode):
return await MemoryAgentService().write_memory(
group_id=end_user_id,
message=self.typed_config.message,
message=self._render_template(self.typed_config.message, state),
config_id=self.typed_config.config_id,
db=db,
storage_type="neo4j",

View File

@@ -387,6 +387,11 @@ class ArrayComparisonOperator(ConditionBase):
return self.right_value not in self.left_value
class NoneObjectComparisonOperator(ConditionBase):
def __getattr__(self, name):
return lambda *args, **kwargs: False
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
@@ -405,6 +410,7 @@ class ConditionExpressionResolver:
float: NumberComparisonOperator,
list: ArrayComparisonOperator,
dict: ObjectComparisonOperator,
type(None): NoneObjectComparisonOperator
}
@classmethod

View File

@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
if category_count > 0:
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}
try:
llm = self._get_llm_instance()
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return f"CASE{category_names.index(category) + 1}"
return {
"class_name": category,
"output": f"CASE{category_names.index(category) + 1}",
}
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}

View File

@@ -1,4 +1,6 @@
from pydantic import Field
from typing import Any
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -1,5 +1,5 @@
import logging
import uuid
import re
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -9,6 +9,8 @@ from app.db import get_db_read
logger = logging.getLogger(__name__)
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class ToolNode(BaseNode):
"""工具节点"""
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
workspace_id = self.get_variable("sys.workspace_id", state)
if workspace_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
logger.error(f"节点 {self.node_id} 缺少租户ID")
return {
"success": False,
"data": "缺少租户ID"
}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, state)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:
# 非模板参数(数字/布尔/普通字符串)直接保留原值
rendered_value = param_template
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"data": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}

View File

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
@@ -25,13 +25,13 @@ class TemplateRenderer:
undefined=StrictUndefined if strict else Undefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
def render(
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
self,
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> str:
"""渲染模板
@@ -69,40 +69,40 @@ class TemplateRenderer:
# variables 的结构:{"sys": {...}, "conv": {...}}
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
if node_outputs:
context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
@@ -121,14 +121,14 @@ class TemplateRenderer:
['模板语法错误: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
return errors
@@ -137,14 +137,16 @@ _default_renderer = TemplateRenderer(strict=True)
def render_template(
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
template: str,
variables: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None,
struct: bool = True
) -> str:
"""渲染模板(便捷函数)
Args:
struct: 渲染模式
template: 模板字符串
variables: 用户变量
node_outputs: 节点输出
@@ -162,7 +164,8 @@ def render_template(
... )
'请分析: 这是一段文本'
"""
return _default_renderer.render(template, variables, node_outputs, system_vars)
renderer = TemplateRenderer(strict=struct)
return renderer.render(template, variables, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:

View File

@@ -87,10 +87,11 @@ class WorkflowValidator:
return graphs
@classmethod
def validate(cls, workflow_config: Union[dict[str, Any], Any]) -> tuple[bool, list[str]]:
def validate(cls, workflow_config: Union[dict[str, Any], Any], publish=False) -> tuple[bool, list[str]]:
"""验证工作流配置
Args:
publish: 发布验证标识
workflow_config: 工作流配置字典或 WorkflowConfig Pydantic 模型
Returns:
@@ -114,7 +115,7 @@ class WorkflowValidator:
graphs = cls.get_subgraph(workflow_config)
logger.info(graphs)
for graph in graphs:
for index, graph in enumerate(graphs):
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
variables = graph.get("variables", [])
@@ -125,10 +126,11 @@ class WorkflowValidator:
elif len(start_nodes) > 1:
errors.append(f"工作流只能有一个 start 节点,当前有 {len(start_nodes)}")
# 2. 验证 end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
if index == len(graphs) - 1:
# 2. 验证 主图end 节点(至少一个)
end_nodes = [n for n in nodes if n.get("type") == NodeType.END]
if len(end_nodes) == 0:
errors.append("工作流必须至少有一个 end 节点")
# 3. 验证节点 ID 唯一性
node_ids = [n.get("id") for n in nodes]
@@ -159,15 +161,17 @@ class WorkflowValidator:
elif target not in node_id_set:
errors.append(f"边 #{i} 的 target 节点不存在: {target}")
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
if publish:
# 仅在发布时验证所有节点可达
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
start_nodes[0]["id"],
edges
)
unreachable = node_id_set - reachable
if unreachable:
errors.append(f"以下节点无法从 start 节点到达: {unreachable}")
# 7. 检测循环依赖(非 loop 节点)
if not errors: # 只有在前面验证通过时才检查循环
@@ -288,7 +292,7 @@ class WorkflowValidator:
(is_valid, errors): 是否有效和错误列表
"""
# 先执行基础验证
is_valid, errors = WorkflowValidator.validate(workflow_config)
is_valid, errors = WorkflowValidator.validate(workflow_config, publish=True)
if not is_valid:
return False, errors

View File

@@ -6,6 +6,7 @@ from .document_model import Document
from .file_model import File
from .generic_file_model import GenericFile
from .models_model import ModelConfig, ModelProvider, ModelType, ModelApiKey
from .memory_short_model import ShortTermMemory, LongTermMemory
from .knowledgeshare_model import KnowledgeShare
from .app_model import App
from .agent_app_config_model import AgentConfig
@@ -25,6 +26,7 @@ from .tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
__all__ = [
"Tenants",
@@ -67,9 +69,12 @@ __all__ = [
"BuiltinToolConfig",
"CustomToolConfig",
"MCPToolConfig",
"ShortTermMemory",
"LongTermMemory",
"ToolExecution",
"ToolType",
"ToolStatus",
"AuthType",
"ExecutionStatus"
"ExecutionStatus",
"MemoryPerceptualModel"
]

View File

@@ -3,7 +3,10 @@ import uuid
from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from app.base.type import PydanticType
from app.db import Base
from app.schemas import ModelParameters
class AgentConfig(Base):
@@ -17,14 +20,17 @@ class AgentConfig(Base):
# Agent 行为配置
system_prompt = Column(Text, nullable=True, comment="系统提示词")
default_model_config_id = Column(UUID(as_uuid=True), ForeignKey("model_configs.id"), nullable=True, index=True, comment="默认模型配置ID")
# 结构化配置(直接存储 JSON
model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等")
# model_parameters = Column(JSON, nullable=True, comment="模型参数配置temperature、max_tokens等")
model_parameters = Column(PydanticType(ModelParameters), nullable=True,
comment="模型参数配置temperature、max_tokens等")
knowledge_retrieval = Column(JSON, nullable=True, comment="知识库检索配置")
memory = Column(JSON, nullable=True, comment="记忆配置")
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
tools = Column(JSON, default=dict, nullable=True, comment="工具配置")
# 多 Agent 相关字段
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
agent_domain = Column(String(50), comment="专业领域: customer_service|technical_support|sales 等")
@@ -41,4 +47,4 @@ class AgentConfig(Base):
parent_agent = relationship("AgentConfig", remote_side=[id], backref="sub_agents")
def __repr__(self):
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"
return f"<AgentConfig(id={self.id}, app_id={self.app_id})>"

View File

@@ -3,6 +3,7 @@
"""
import uuid
import datetime
from sqlalchemy import Column, String, DateTime, ForeignKey, Boolean, Integer, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -25,56 +26,69 @@ class Conversation(Base):
__tablename__ = "conversations"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联信息
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False, comment="应用ID")
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, comment="工作空间ID")
user_id = Column(String, nullable=True, comment="用户ID外部系统")
# 会话信息
title = Column(String(255), comment="会话标题")
summary = Column(Text, comment="会话摘要")
# 会话类型True=草稿会话使用草稿配置False=发布会话(使用发布配置)
is_draft = Column(Boolean, default=True, nullable=False, comment="是否为草稿会话")
# 配置快照:保存创建会话时的完整配置,用于审计和问题追溯
config_snapshot = Column(JSON, comment="配置快照Agent配置、模型配置等")
# 统计信息
message_count = Column(Integer, default=0, comment="消息数量")
# 状态
is_active = Column(Boolean, default=True, nullable=False, comment="是否活跃")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now, comment="更新时间")
# 关联关系
app = relationship("App", back_populates="conversations")
workspace = relationship("Workspace")
messages = relationship("Message", back_populates="conversation", cascade="all, delete-orphan")
class ConversationDetail(Base):
__tablename__ = "conversation_details"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"))
theme = Column(String, comment="会话主题")
summary = Column(String, comment="会话摘要")
takeaways = Column(JSON, comment="会话要点")
question = Column(JSON, comment="用户问题")
info_score = Column(Integer, comment="会话信息量评分")
class Message(Base):
"""消息表"""
__tablename__ = "messages"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
# 关联信息
conversation_id = Column(UUID(as_uuid=True), ForeignKey("conversations.id"), nullable=False, comment="会话ID")
# 消息内容
role = Column(String(20), nullable=False, comment="角色: user/assistant/system")
content = Column(Text, nullable=False, comment="消息内容")
# 元数据(避免使用 metadata 保留字)
meta_data = Column(JSON, comment="消息元数据如模型、token使用等")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间")
# 关联关系
conversation = relationship("Conversation", back_populates="messages")

View File

@@ -25,7 +25,6 @@ class DataConfig(Base):
llm_id = Column(String, nullable=True, comment="LLM模型配置ID")
embedding_id = Column(String, nullable=True, comment="嵌入模型配置ID")
rerank_id = Column(String, nullable=True, comment="重排序模型配置ID")
llm = Column(String, nullable=True, comment="LLM模型配置ID")
# 记忆萃取引擎配置
enable_llm_dedup_blockwise = Column(Boolean, default=True, comment="启用LLM决策去重")

View File

@@ -0,0 +1,40 @@
"""
遗忘周期历史记录模型
用于存储每次遗忘周期执行的历史数据,支持趋势分析和可视化。
"""
import uuid
from datetime import datetime
from sqlalchemy import Column, Integer, String, Float, DateTime, Index
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class ForgettingCycleHistory(Base):
"""遗忘周期历史记录表"""
__tablename__ = "forgetting_cycle_history"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True, comment="主键ID")
end_user_id = Column(String(255), nullable=False, comment="终端用户ID")
execution_time = Column(DateTime, nullable=False, default=datetime.now, comment="执行时间")
merged_count = Column(Integer, default=0, comment="本次成功融合的节点对数")
failed_count = Column(Integer, default=0, comment="本次融合失败的节点对数")
average_activation_value = Column(Float, nullable=True, comment="平均激活值")
total_nodes = Column(Integer, default=0, comment="总节点数")
low_activation_nodes = Column(Integer, default=0, comment="低于遗忘阈值的节点总数(包含已融合、失败和待处理的)")
duration_seconds = Column(Float, nullable=True, comment="执行耗时(秒)")
trigger_type = Column(String(50), default="manual", comment="触发类型: manual/scheduled")
# 创建索引以优化查询
__table_args__ = (
Index('idx_end_user_time', 'end_user_id', 'execution_time'),
Index('idx_execution_time', 'execution_time'),
)
def __repr__(self):
return (
f"<ForgettingCycleHistory(id={self.id}, end_user_id={self.end_user_id}, "
f"merged_count={self.merged_count}, execution_time={self.execution_time})>"
)

View File

@@ -0,0 +1,40 @@
import datetime
import uuid
from enum import IntEnum
from sqlalchemy import Column, ForeignKey, Integer, DateTime, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base
class PerceptualType(IntEnum):
VISION = 1
AUDIO = 2
TEXT = 3
CONVERSATION = 4
class FileStorageType(IntEnum):
LOCAL = 1
REMOTE = 2
class MemoryPerceptualModel(Base):
__tablename__ = "memory_perceptual"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
end_user_id = Column(UUID(as_uuid=True), ForeignKey("end_users.id"), index=True)
perceptual_type = Column(Integer, index=True, nullable=False, comment="感知类型")
storage_service = Column(Integer, default=0, comment="存储服务类型")
file_path = Column(String, nullable=False, comment="文件路径")
file_name = Column(String, nullable=False, comment="文件名称")
file_ext = Column(String, nullable=False, comment="文件后缀名")
summary = Column(String, comment="摘要")
meta_data = Column(JSONB, comment="元信息")
created_time = Column(DateTime, default=datetime.datetime.now, comment="创建时间")

View File

@@ -0,0 +1,60 @@
"""
记忆模型 - 短期记忆和长期记忆表
"""
import uuid
import datetime
from sqlalchemy import Column, String, DateTime, Text, JSON
from sqlalchemy.dialects.postgresql import UUID
from app.db import Base
class ShortTermMemory(Base):
"""短期记忆表
用于存储临时的对话记忆,通常保存较短时间
"""
__tablename__ = "memory_short_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 对话内容
messages = Column(Text, nullable=False, comment="用户消息内容")
aimessages = Column(Text, nullable=True, comment="AI回复消息内容")
# 搜索开关
search_switch = Column(String(50), nullable=True, comment="搜索开关状态")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<ShortTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"
class LongTermMemory(Base):
"""长期记忆表
用于存储重要的对话记忆,长期保存
"""
__tablename__ = "memory_long_term"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True, comment="记忆ID")
# 用户信息
end_user_id = Column(String(255), nullable=False, index=True, comment="终端用户ID")
# 检索内容 - 存储为JSON格式的列表包含字典 [{}, {}]
retrieved_content = Column(JSON, nullable=True, default=list, comment="检索到的相关内容,格式为[{}, {}]")
# 时间戳
created_at = Column(DateTime, default=datetime.datetime.now, nullable=False, index=True, comment="创建时间")
def __repr__(self):
return f"<LongTermMemory(id={self.id}, end_user_id={self.end_user_id}, created_at={self.created_at})>"

View File

@@ -8,15 +8,15 @@ from sqlalchemy import Column, String, Boolean, DateTime, Integer, Float, Text,
from sqlalchemy.dialects.postgresql import UUID, JSON
from sqlalchemy.orm import relationship
from app.base.type import PydanticType
from app.db import Base
from app.schemas import ModelParameters
class OrchestrationMode(StrEnum):
"""图标类型枚举"""
SEQUENTIAL = "sequential"
PARALLEL = "parallel"
CONDITIONAL = "conditional"
"""协作模式枚举"""
COLLABORATION = "collaboration" # 协作模式Agent 之间可以相互 handoff
SUPERVISOR = "supervisor" # 监督模式:由主 Agent 统一调度子 Agent
class AggregationStrategy(StrEnum):
"""图标类型枚举"""
@@ -24,27 +24,27 @@ class AggregationStrategy(StrEnum):
VOTE = "vote"
PRIORITY = "priority"
class PydanticType(TypeDecorator):
impl = JSON
# class PydanticType(TypeDecorator):
# impl = JSON
def __init__(self, pydantic_model: type[BaseModel]):
super().__init__()
self.model = pydantic_model
# def __init__(self, pydantic_model: type[BaseModel]):
# super().__init__()
# self.model = pydantic_model
def process_bind_param(self, value, dialect):
# 入库Model -> dict
if value is None:
return None
if isinstance(value, self.model):
return value.dict()
return value # 已经是 dict 也放行
# def process_bind_param(self, value, dialect):
# # 入库Model -> dict
# if value is None:
# return None
# if isinstance(value, self.model):
# return value.dict()
# return value # 已经是 dict 也放行
def process_result_value(self, value, dialect):
# 出库dict -> Model
if value is None:
return None
# return self.model.parse_obj(value) # pydantic v1
return self.model.model_validate(value) # pydantic v2
# def process_result_value(self, value, dialect):
# # 出库dict -> Model
# if value is None:
# return None
# # return self.model.parse_obj(value) # pydantic v1
# return self.model.model_validate(value) # pydantic v2
class MultiAgentConfig(Base):
"""多 Agent 配置表"""
@@ -66,8 +66,8 @@ class MultiAgentConfig(Base):
orchestration_mode = Column(
String(20),
nullable=False,
default="conditional",
comment="协作模式: sequential|parallel|conditional|loop"
default="collaboration",
comment="协作模式: collaboration协作| supervisor监督"
)
# 子 Agent 列表

View File

@@ -0,0 +1,317 @@
import uuid
from typing import Optional
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session
from app.core.exceptions import ResourceNotFoundException
from app.core.logging_config import get_db_logger
from app.models import Conversation, Message
from app.models.conversation_model import ConversationDetail
logger = get_db_logger()
class ConversationRepository:
"""Repository for Conversation entity, encapsulating CRUD operations."""
def __init__(self, db: Session):
self.db = db
def create_conversation(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
title: Optional[str] = None,
is_draft: bool = False,
config_snapshot: Optional[dict] = None
) -> Conversation:
"""
Create a new conversation record.
Args:
app_id: Application ID the conversation belongs to.
workspace_id: Workspace ID where the conversation is created.
user_id: Optional user ID associated with the conversation.
title: Optional conversation title. Defaults to "New Conversation".
is_draft: Whether the conversation is a draft.
config_snapshot: Optional configuration snapshot.
Returns:
Conversation: Newly created Conversation instance.
"""
conversation = Conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=user_id,
title=title or "New Conversation",
is_draft=is_draft,
config_snapshot=config_snapshot
)
self.db.add(conversation)
return conversation
def get_conversation_by_conversation_id(
self,
conversation_id: uuid.UUID,
workspace_id: Optional[uuid.UUID] = None
) -> Conversation:
"""
Retrieve a conversation by its ID, optionally filtered by workspace.
Args:
conversation_id: The UUID of the conversation.
workspace_id: Optional workspace UUID to filter the conversation.
Raises:
ResourceNotFoundException: If conversation does not exist.
Returns:
Conversation: The matching Conversation instance.
"""
logger.info(f"Fetching conversation: {conversation_id}")
stmt = select(Conversation).where(Conversation.id == conversation_id)
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
conversation = self.db.scalars(stmt).first()
if not conversation:
logger.warning(f"Conversation not found: {conversation_id}")
raise ResourceNotFoundException("Conversation", str(conversation_id))
logger.info(f"Conversation fetched successfully: {conversation_id}")
return conversation
def get_conversation_by_user_id(
self,
user_id: uuid.UUID,
workspace_id: uuid.UUID = None,
limit: int = 10,
is_activate: bool = True
) -> list[Conversation]:
"""
Retrieve recent conversations for a specific user.
This method queries conversations associated with the given user ID,
optionally scoped to a specific workspace. Results are ordered by the
most recently updated conversations and limited to a fixed number.
Args:
user_id (uuid.UUID): Unique identifier of the user.
workspace_id (uuid.UUID, optional): Workspace scope for the query.
If provided, only conversations under this workspace will be returned.
limit (int): Maximum number of conversations to return.
Defaults to 10.
is_activate (bool): Convsersation State limit
Returns:
list[Conversation]: A list of conversation entities ordered by
last updated time (descending).
"""
logger.info(f"Fetching conversation by user_id: {user_id}")
stmt = select(Conversation).where(
Conversation.user_id == str(user_id),
Conversation.is_active.is_(is_activate)
)
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.limit(limit)
convsersations = list(self.db.scalars(stmt).all())
logger.info(
"Conversation fetched successfully",
extra={
"user_id": str(user_id),
"workspace_id": str(workspace_id),
}
)
return convsersations
def list_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
user_id: Optional[str] = None,
is_draft: Optional[bool] = None,
page: int = 1,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""
List conversations with optional filters and pagination.
Args:
app_id: Application ID filter.
workspace_id: Workspace ID filter.
user_id: Optional user ID filter.
is_draft: Optional draft status filter.
page: Page number (1-based).
pagesize: Number of items per page.
Returns:
Tuple[List[Conversation], int]: List of Conversation instances and total count.
"""
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True)
)
if user_id:
stmt = stmt.where(Conversation.user_id == str(user_id))
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# Calculate total number of records
total = int(self.db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
# Apply pagination
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
logger.info(
"Listed conversations successfully",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"returned": len(conversations),
"total": total
}
)
return conversations, total
def soft_delete_conversation_by_conversation_id(
self,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID,
):
"""
Soft delete a conversation by setting is_active to False.
Args:
conversation_id: The UUID of the conversation.
workspace_id: Workspace ID for verification.
"""
conversation = self.get_conversation_by_conversation_id(
conversation_id,
workspace_id
)
conversation.is_active = False
def get_conversation_detail(
self,
conversation_id: uuid.UUID
) -> ConversationDetail | None:
"""
Retrieve the detail of a conversation by its ID.
Args:
conversation_id (UUID): The unique identifier of the conversation.
Returns:
ConversationDetail or None: The conversation detail object if found,
otherwise None.
Notes:
- This method queries the database but does not modify it.
- The caller is responsible for handling the case where None is returned.
"""
stmt = select(ConversationDetail).where(
ConversationDetail.conversation_id == conversation_id
)
detail = self.db.scalars(stmt).first()
return detail
def add_conversation_detail(
self,
conversation_detail: ConversationDetail,
):
"""
Add a new conversation detail record to the database session.
Args:
conversation_detail (ConversationDetail): The ORM object representing
the conversation detail to add.
Returns:
ConversationDetail: The same object added to the session.
Notes:
- This method only adds the object to the current session.
- It does not commit the transaction; commit/rollback is handled
by the caller.
- Useful for batch operations or transactional control.
"""
self.db.add(conversation_detail)
return conversation_detail
class MessageRepository:
"""Repository for Message entity, encapsulating CRUD operations."""
def __init__(self, db: Session):
self.db = db
def add_message(self, message: Message) -> Message:
"""
Add a new message record to the conversation.
Args:
message (Message): The Message ORM object to be added.
Returns:
Message: The same message object added to the conversation.
Notes:
- This method only adds the object to the current conversation.
- It does not commit the transaction; commit/rollback should be handled
by the caller.
- Useful for transactional control or batch operations.
"""
self.db.add(message)
return message
def get_message_by_conversation_id(
self,
conversation_id: uuid.UUID,
limit: Optional[int] = None
) -> list[Message]:
"""
Retrieve messages by conversation ID.
Args:
conversation_id: The UUID of the conversation.
limit: Optional limit on the number of messages returned.
Returns:
List[Message]: List of Message instances.
"""
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
if limit:
stmt = stmt.limit(limit)
messages = list(self.db.scalars(stmt).all())
logger.info(
"Fetched messages successfully",
extra={
"conversation_id": str(conversation_id),
"returned": len(messages)
}
)
return messages

View File

@@ -327,7 +327,7 @@ class DataConfigRepository:
# 更新字段映射
field_mapping = {
# 模型选择
"llm_id": "llm",
"llm_id": "llm_id",
"embedding_id": "embedding_id",
"rerank_id": "rerank_id",
# 记忆萃取引擎

View File

@@ -0,0 +1,105 @@
"""
遗忘周期历史记录仓储
提供遗忘周期历史记录的数据访问操作。
"""
from typing import List, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import desc, and_
from app.models.forgetting_cycle_history_model import ForgettingCycleHistory
class ForgettingCycleHistoryRepository:
"""遗忘周期历史记录仓储类"""
def create(
self,
db: Session,
end_user_id: str,
execution_time: datetime,
merged_count: int,
failed_count: int,
average_activation_value: Optional[float],
total_nodes: int,
low_activation_nodes: int,
duration_seconds: float,
trigger_type: str = "manual"
) -> ForgettingCycleHistory:
"""
创建历史记录
Args:
db: 数据库会话
end_user_id: 终端用户ID
execution_time: 执行时间
merged_count: 融合节点数
failed_count: 失败节点数
average_activation_value: 平均激活值
total_nodes: 总节点数
low_activation_nodes: 低激活值节点数
duration_seconds: 执行耗时
trigger_type: 触发类型
Returns:
ForgettingCycleHistory: 创建的历史记录
"""
history = ForgettingCycleHistory(
end_user_id=end_user_id,
execution_time=execution_time,
merged_count=merged_count,
failed_count=failed_count,
average_activation_value=average_activation_value,
total_nodes=total_nodes,
low_activation_nodes=low_activation_nodes,
duration_seconds=duration_seconds,
trigger_type=trigger_type
)
db.add(history)
db.commit()
db.refresh(history)
return history
def get_recent_by_end_user(
self,
db: Session,
end_user_id: str
) -> List[ForgettingCycleHistory]:
"""
获取指定终端用户的所有历史记录(按时间降序排列)
注意:此方法返回所有历史记录,调用方需要自行处理日期分组和数量限制。
Args:
db: 数据库会话
end_user_id: 终端用户ID
Returns:
List[ForgettingCycleHistory]: 历史记录列表,按时间降序排列
"""
return db.query(ForgettingCycleHistory).filter(
ForgettingCycleHistory.end_user_id == end_user_id
).order_by(ForgettingCycleHistory.execution_time.desc()).all()
def get_latest_by_end_user(
self,
db: Session,
end_user_id: str
) -> Optional[ForgettingCycleHistory]:
"""
获取指定终端用户的最新历史记录
Args:
db: 数据库会话
end_user_id: 终端用户ID
Returns:
Optional[ForgettingCycleHistory]: 最新历史记录
"""
return db.query(ForgettingCycleHistory).filter(
ForgettingCycleHistory.end_user_id == end_user_id
).order_by(desc(ForgettingCycleHistory.execution_time)).first()

View File

@@ -0,0 +1,156 @@
import uuid
from datetime import datetime
from typing import List, Tuple, Optional
from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from app.core.logging_config import get_db_logger
from app.models.memory_perceptual_model import MemoryPerceptualModel, PerceptualType, FileStorageType
from app.schemas.memory_perceptual_schema import PerceptualQuerySchema
db_logger = get_db_logger()
class MemoryPerceptualRepository:
"""Data Access Layer for perceptual memory"""
def __init__(self, db: Session):
self.db = db
# ==================== Create and update ====================
def create_perceptual_memory(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
file_path: str,
file_name: str,
file_ext: str,
summary: Optional[str] = None,
meta_data: Optional[dict] = None,
storage_service: FileStorageType = FileStorageType.LOCAL
) -> MemoryPerceptualModel:
"""Create perceptual memory"""
db_logger.debug(f"Creating perceptual memory: end_user_id={end_user_id}, "
f"type={perceptual_type}, file={file_name}")
try:
perceptual_memory = MemoryPerceptualModel(
end_user_id=end_user_id,
perceptual_type=perceptual_type,
storage_service=storage_service,
file_path=file_path,
file_name=file_name,
file_ext=file_ext,
summary=summary,
meta_data=meta_data,
created_time=datetime.now()
)
self.db.add(perceptual_memory)
self.db.flush()
db_logger.info(f"Perceptual memory created successfully: id={perceptual_memory.id}, file={file_name}")
return perceptual_memory
except Exception as e:
db_logger.error(f"Failed to create perceptual memory: end_user_id={end_user_id} - {str(e)}")
raise
# ==================== Query ====================
def get_count_by_user_id(
self,
end_user_id: uuid.UUID,
):
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}")
try:
count = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id
).count()
return count
except Exception as e:
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
raise
def get_count_by_type(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
):
db_logger.debug(f"Querying perceptual memory Count: end_user_id={end_user_id}, type={perceptual_type}")
try:
count = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id,
MemoryPerceptualModel.perceptual_type == perceptual_type
).count()
return count
except Exception as e:
db_logger.error(f"Failed to query perceptual memory count: end_user_id={end_user_id} - {str(e)}")
raise
def get_timeline(
self,
end_user_id: uuid.UUID,
query: PerceptualQuerySchema
) -> Tuple[int, List[MemoryPerceptualModel]]:
"""Get the timeline of a user's perceptual memories"""
db_logger.debug(f"Querying perceptual memory timeline: end_user_id={end_user_id}, filter={query.filter}")
try:
base_query = self.db.query(MemoryPerceptualModel).filter(
MemoryPerceptualModel.end_user_id == end_user_id
)
if query.filter.type is not None:
base_query = base_query.filter(
MemoryPerceptualModel.perceptual_type == query.filter.type
)
total_count = base_query.count()
memories = base_query.order_by(
desc(MemoryPerceptualModel.created_time)
).offset(
(query.page - 1) * query.page_size
).limit(query.page_size).all()
db_logger.info(
f"Perceptual memory timeline query succeeded: end_user_id={end_user_id}, total={total_count}, returned={len(memories)}")
return total_count, memories
except Exception as e:
db_logger.error(f"Failed to query perceptual memory timeline: end_user_id={end_user_id} - {str(e)}")
raise
def get_by_type(
self,
end_user_id: uuid.UUID,
perceptual_type: PerceptualType,
limit: int = 10,
offset: int = 0
) -> List[MemoryPerceptualModel]:
"""Get memories by perceptual type"""
db_logger.debug(f"Querying perceptual memories by type: end_user_id={end_user_id}, type={perceptual_type}")
try:
memories = self.db.query(MemoryPerceptualModel).filter(
and_(
MemoryPerceptualModel.end_user_id == end_user_id,
MemoryPerceptualModel.perceptual_type == perceptual_type
)
).order_by(
desc(MemoryPerceptualModel.created_time)
).offset(offset).limit(limit).all()
db_logger.debug(f"Query by type succeeded: count={len(memories)}")
return memories
except Exception as e:
db_logger.error(f"Failed to query perceptual memories by type: end_user_id={end_user_id}, "
f"type={perceptual_type} - {str(e)}")
raise

View File

@@ -0,0 +1,503 @@
"""
记忆仓储模块 - 短期记忆和长期记忆的数据访问层
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
import uuid
import datetime
from app.models.memory_short_model import ShortTermMemory, LongTermMemory
from app.core.logging_config import get_db_logger
# 获取数据库专用日志器
db_logger = get_db_logger()
class ShortTermMemoryRepository:
"""短期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建短期记忆记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建的短期记忆对象
"""
try:
memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建短期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建短期记忆记录时出错: {str(e)}")
raise
def count_by_user_id(self,end_user_id: str) -> int:
"""根据ID获取短期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[ShortTermMemory]: 记忆对象如果不存在则返回None
"""
try:
count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"成功统计用户 {end_user_id} 的短期记忆数量: {count}")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"查询短期记忆记录 {count} 时出错: {str(e)}")
raise
def get_latest_by_user_id(self, end_user_id: str, limit: int = 5) -> List[ShortTermMemory]:
"""获取用户最新的短期记忆记录
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认5条
Returns:
List[ShortTermMemory]: 最新的记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.end_user_id == end_user_id)
.order_by(ShortTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 的最新 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的最新短期记忆记录时出错: {str(e)}")
raise
def get_recent_by_user_id(self, end_user_id: str, hours: int = 24) -> List[ShortTermMemory]:
"""获取用户最近指定小时内的短期记忆记录
Args:
end_user_id: 终端用户ID
hours: 时间范围小时默认24小时
Returns:
List[ShortTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(hours=hours)
# 使用复合索引 ix_memory_short_term_user_time 优化查询
memories = (
self.db.query(ShortTermMemory)
.filter(
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.created_at >= cutoff_time
)
.order_by(ShortTermMemory.created_at.desc())
.all()
)
db_logger.info(f"成功查询用户 {end_user_id} 最近 {hours} 小时的 {len(memories)} 条短期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 最近 {hours} 小时的短期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的短期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除短期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到短期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除短期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def delete_old_memories(self, days: int = 7) -> int:
"""删除指定天数之前的短期记忆记录
Args:
days: 保留天数默认7天
Returns:
int: 删除的记录数
"""
try:
cutoff_time = datetime.datetime.now() - datetime.timedelta(days=days)
deleted_count = (
self.db.query(ShortTermMemory)
.filter(ShortTermMemory.created_at < cutoff_time)
.delete(synchronize_session=False)
)
self.db.commit()
db_logger.info(f"成功删除 {days} 天前的 {deleted_count} 条短期记忆记录")
return deleted_count
except Exception as e:
self.db.rollback()
db_logger.error(f"删除 {days} 天前的短期记忆记录时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, messages: str, aimessages: str = None, search_switch: str = None, retrieved_content: List[Dict] = None) -> ShortTermMemory:
"""创建或更新短期记忆记录
根据 end_user_id、messages 和 aimessages 查找现有记录:
- 如果找到匹配的记录,则更新 messages、aimessages、search_switch 和 retrieved_content
- 如果没有找到匹配的记录,则创建新记录
Args:
end_user_id: 终端用户ID
messages: 用户消息内容
aimessages: AI回复消息内容
search_switch: 搜索开关状态
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
ShortTermMemory: 创建或更新的短期记忆对象
"""
try:
# 构建查询条件,使用复合索引 ix_memory_short_term_user_messages 优化查询
query_filters = [
ShortTermMemory.end_user_id == end_user_id,
ShortTermMemory.messages == messages
]
# 如果 aimessages 不为空,则加入查询条件
if aimessages is not None:
query_filters.append(ShortTermMemory.aimessages == aimessages)
else:
# 如果 aimessages 为 None则查找 aimessages 为 NULL 的记录
query_filters.append(ShortTermMemory.aimessages.is_(None))
# 查找现有记录
existing_memory = (
self.db.query(ShortTermMemory)
.filter(*query_filters)
.first()
)
if existing_memory:
# 更新现有记录
existing_memory.messages = messages
existing_memory.aimessages = aimessages
existing_memory.search_switch = search_switch
existing_memory.retrieved_content = retrieved_content or []
self.db.commit()
self.db.refresh(existing_memory)
db_logger.info(f"成功更新短期记忆记录: {existing_memory.id} for user {end_user_id}")
return existing_memory
else:
# 创建新记录
new_memory = ShortTermMemory(
end_user_id=end_user_id,
messages=messages,
aimessages=aimessages,
search_switch=search_switch,
retrieved_content=retrieved_content or []
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的短期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新短期记忆记录时出错: {str(e)}")
raise
class LongTermMemoryRepository:
"""长期记忆仓储类"""
def __init__(self, db: Session):
self.db = db
def create(self, end_user_id: str, retrieved_content: List[Dict] = None) -> LongTermMemory:
"""创建长期记忆记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
LongTermMemory: 创建的长期记忆对象
"""
try:
memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content or []
)
self.db.add(memory)
self.db.commit()
self.db.refresh(memory)
db_logger.info(f"成功创建长期记忆记录: {memory.id} for user {end_user_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建长期记忆记录时出错: {str(e)}")
raise
def get_by_id(self, memory_id: uuid.UUID) -> Optional[LongTermMemory]:
"""根据ID获取长期记忆记录
Args:
memory_id: 记忆ID
Returns:
Optional[LongTermMemory]: 记忆对象如果不存在则返回None
"""
try:
memory = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.first()
)
if memory:
db_logger.debug(f"成功查询到长期记忆记录 {memory_id}")
else:
db_logger.debug(f"未找到长期记忆记录 {memory_id}")
return memory
except Exception as e:
self.db.rollback()
db_logger.error(f"查询长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def get_by_user_id(self, end_user_id: str, limit: int = 100, offset: int = 0) -> List[LongTermMemory]:
"""根据用户ID获取长期记忆记录列表
Args:
end_user_id: 终端用户ID
limit: 返回记录数限制默认100
offset: 偏移量默认0
Returns:
List[LongTermMemory]: 记忆记录列表,按创建时间倒序
"""
try:
# 使用复合索引 ix_memory_long_term_user_time 优化查询
memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.offset(offset)
.all()
)
db_logger.info(f"成功查询用户 {end_user_id}{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"查询用户 {end_user_id} 的长期记忆记录时出错: {str(e)}")
raise
def search_by_content(self, end_user_id: str, keyword: str, limit: int = 50) -> List[LongTermMemory]:
"""根据内容关键词搜索长期记忆记录
Args:
end_user_id: 终端用户ID
keyword: 搜索关键词
limit: 返回记录数限制默认50
Returns:
List[LongTermMemory]: 匹配的记忆记录列表,按创建时间倒序
"""
try:
# 使用 GIN 索引 ix_memory_long_term_retrieved_content_gin 优化 JSON 搜索
# 同时使用复合索引 ix_memory_long_term_user_time 优化用户过滤
memories = (
self.db.query(LongTermMemory)
.filter(
LongTermMemory.end_user_id == end_user_id,
LongTermMemory.retrieved_content.astext.contains(keyword)
)
.order_by(LongTermMemory.created_at.desc())
.limit(limit)
.all()
)
db_logger.info(f"成功搜索用户 {end_user_id} 包含关键词 '{keyword}'{len(memories)} 条长期记忆记录")
return memories
except Exception as e:
self.db.rollback()
db_logger.error(f"搜索用户 {end_user_id} 包含关键词 '{keyword}' 的长期记忆记录时出错: {str(e)}")
raise
def delete_by_id(self, memory_id: uuid.UUID) -> bool:
"""删除指定ID的长期记忆记录
Args:
memory_id: 记忆ID
Returns:
bool: 删除成功返回True否则返回False
"""
try:
deleted_count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.id == memory_id)
.delete(synchronize_session=False)
)
self.db.commit()
if deleted_count > 0:
db_logger.info(f"成功删除长期记忆记录 {memory_id}")
return True
else:
db_logger.warning(f"未找到长期记忆记录 {memory_id},无法删除")
return False
except Exception as e:
self.db.rollback()
db_logger.error(f"删除长期记忆记录 {memory_id} 时出错: {str(e)}")
raise
def count_by_user_id(self, end_user_id: str) -> int:
"""统计用户的长期记忆记录数量
Args:
end_user_id: 终端用户ID
Returns:
int: 记录数量
"""
try:
count = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.count()
)
db_logger.debug(f"用户 {end_user_id} 共有 {count} 条长期记忆记录")
return count
except Exception as e:
self.db.rollback()
db_logger.error(f"统计用户 {end_user_id} 的长期记忆记录数量时出错: {str(e)}")
raise
def upsert(self, end_user_id: str, retrieved_content: List[Dict] = None) -> Optional[LongTermMemory]:
"""创建或更新长期记忆记录
根据 end_user_id 和 retrieved_content 判断是否需要写入:
- 如果找到相同的 end_user_id 和 retrieved_content则不写入返回 None
- 如果没有找到相同的记录,则创建新记录
Args:
end_user_id: 终端用户ID
retrieved_content: 检索到的相关内容,格式为[{}, {}]
Returns:
Optional[LongTermMemory]: 创建的长期记忆对象,如果不需要写入则返回 None
"""
try:
retrieved_content = retrieved_content or []
# 优化查询:使用复合索引 ix_memory_long_term_user_time 先过滤用户
# 然后在应用层比较 JSON 内容,避免复杂的数据库 JSON 比较
existing_memories = (
self.db.query(LongTermMemory)
.filter(LongTermMemory.end_user_id == end_user_id)
.order_by(LongTermMemory.created_at.desc())
.limit(100) # 限制查询数量,避免加载过多数据
.all()
)
# 在 Python 中比较 retrieved_content
for memory in existing_memories:
if memory.retrieved_content == retrieved_content:
# 如果找到相同的记录,不写入
db_logger.info(f"长期记忆记录已存在,跳过写入: user {end_user_id}")
return None
# 如果没有找到相同的记录,创建新记录
new_memory = LongTermMemory(
end_user_id=end_user_id,
retrieved_content=retrieved_content
)
self.db.add(new_memory)
self.db.commit()
self.db.refresh(new_memory)
db_logger.info(f"成功创建新的长期记忆记录: {new_memory.id} for user {end_user_id}")
return new_memory
except Exception as e:
self.db.rollback()
db_logger.error(f"创建或更新长期记忆记录时出错: {str(e)}")
raise

View File

@@ -211,6 +211,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
"dialog_id": s.dialog_id,
"chunk_ids": s.chunk_ids,
"content": s.content,
"memory_type": s.memory_type, # 添加 memory_type 字段
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
"config_id": s.config_id, # 添加 config_id
})

View File

@@ -92,6 +92,11 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
WHEN entity.description IS NOT NULL AND entity.description <> ''
AND (e.description IS NULL OR size(e.description) = 0 OR size(entity.description) > size(e.description))
THEN entity.description ELSE e.description END,
e.example = CASE
WHEN entity.example IS NOT NULL AND entity.example <> ''
THEN entity.example
ELSE coalesce(e.example, '')
END,
e.statement_id = CASE WHEN entity.statement_id IS NOT NULL AND entity.statement_id <> '' THEN entity.statement_id ELSE e.statement_id END,
e.aliases = CASE
WHEN entity.aliases IS NOT NULL AND size(entity.aliases) > 0
@@ -121,7 +126,8 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END,
e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END,
e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END,
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END,
e.is_explicit_memory = CASE WHEN entity.is_explicit_memory IS NOT NULL THEN entity.is_explicit_memory ELSE coalesce(e.is_explicit_memory, false) END
RETURN e.id AS uuid
"""
@@ -722,7 +728,12 @@ SET m += {
chunk_ids: summary.chunk_ids,
content: summary.content,
summary_embedding: summary.summary_embedding,
config_id: summary.config_id
config_id: summary.config_id,
importance_score: CASE WHEN summary.importance_score IS NOT NULL THEN summary.importance_score ELSE coalesce(m.importance_score, 0.5) END,
activation_value: CASE WHEN summary.activation_value IS NOT NULL THEN summary.activation_value ELSE m.activation_value END,
access_history: CASE WHEN summary.access_history IS NOT NULL THEN summary.access_history ELSE coalesce(m.access_history, []) END,
last_access_time: CASE WHEN summary.last_access_time IS NOT NULL THEN summary.last_access_time ELSE m.last_access_time END,
access_count: CASE WHEN summary.access_count IS NOT NULL THEN summary.access_count ELSE coalesce(m.access_count, 0) END
}
RETURN m.id AS uuid
"""
@@ -857,3 +868,174 @@ neo4j_query_all = """
"""
'''针对当前节点下扩长的句子,实体和总结'''
Memory_Timeline_ExtractedEntity="""
MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) = $id
AND (ms:ExtractedEntity OR ms:MemorySummary)
RETURN
collect(
DISTINCT
CASE
WHEN ms:ExtractedEntity THEN {
text: ms.name,
created_at: ms.created_at,
type: "情景记忆"
}
END
) AS ExtractedEntity,
collect(
DISTINCT
CASE
WHEN ms:MemorySummary THEN {
text: ms.content,
created_at: ms.created_at,
type: "长期沉淀"
}
END
) AS MemorySummary,
collect(
DISTINCT {
text: e.statement,
created_at: e.created_at,
type: "情绪记忆"
}
) AS statement;
"""
Memory_Timeline_MemorySummary="""
MATCH (n)-[r1]-(e)-[r2]-(ms)
WHERE elementId(n) =$id
AND (ms:MemorySummary OR ms:ExtractedEntity)
RETURN
collect(
DISTINCT
CASE
WHEN ms:ExtractedEntity THEN {
text: ms.name,
created_at: ms.created_at
}
END
) AS ExtractedEntity,
collect(
DISTINCT
CASE
WHEN n:MemorySummary THEN {
text: n.content,
created_at: n.created_at
}
END
) AS MemorySummary,
collect(
DISTINCT {
text: e.statement,
created_at: e.created_at
}
) AS statement;
"""
Memory_Timeline_Statement="""
MATCH (n)
WHERE elementId(n) = $id
CALL {
WITH n
MATCH (n)-[]-(m:ExtractedEntity)
WHERE NOT m:MemorySummary AND NOT m:Chunk
RETURN collect(
DISTINCT {
text: m.name,
created_at: m.created_at,
type: "情景记忆"
}
) AS ExtractedEntity
}
CALL {
WITH n
MATCH (n)-[]-(m:MemorySummary)
WHERE NOT m:Chunk
RETURN collect(
DISTINCT {
text: m.content,
created_at: m.created_at,
type: "长期沉淀"
}
) AS MemorySummary
}
RETURN
ExtractedEntity,
MemorySummary,
{
text: n.statement,
created_at: n.created_at,
type: "情绪记忆"
} AS statement;
"""
'''针对当前节点,主要获取更加完整的句子节点'''
Memory_Space_Emotion_Statement="""
MATCH (n)
WHERE elementId(n) = $id
RETURN
n.emotion_intensity AS emotion_intensity,
n.created_at AS created_at,
n.emotion_type AS emotion_type,
n.statement AS statement;
"""
Memory_Space_Emotion_MemorySummary="""
MATCH (n)-[]-(e)
WHERE elementId(n) = $id
AND EXISTS {
MATCH (e)-[]-(ms)
WHERE ms:MemorySummary OR ms:ExtractedEntity
}
RETURN DISTINCT
e.emotion_intensity AS emotion_intensity,
e.created_at AS created_at,
e.emotion_type AS emotion_type,
e.statement AS statement;
"""
Memory_Space_Emotion_ExtractedEntity="""
MATCH (n)-[]-(e)
WHERE elementId(n) = $id
AND EXISTS {
MATCH (e)-[]-(ms:ExtractedEntity)
}
RETURN DISTINCT
e.emotion_intensity AS emotion_intensity,
e.created_at AS created_at,
e.emotion_type AS emotion_type,
e.statement AS statement;
"""
'''获取实体'''
Memory_Space_User="""
MATCH (n)-[r]->(m)
WHERE n.group_id = $group_id AND m.name="用户"
return DISTINCT elementId(m) as id
"""
Memory_Space_Entity="""
MATCH (n)-[]-(m)
WHERE elementId(m) = $id AND m.entity_type = "Person"
RETURN
DISTINCT m.name as name,m.group_id as group_id
"""
Memory_Space_Associative="""
MATCH (u)-[]-(x)-[]-(h)
WHERE elementId(u) = $user_id
AND elementId(h) = $id
RETURN DISTINCT
x.statement as statement,x.created_at as created_at
"""

View File

@@ -58,7 +58,7 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
# 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', [])
n['access_history'] = n.get('access_history') or []
n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0)

View File

@@ -0,0 +1,273 @@
# -*- coding: utf-8 -*-
"""Memory Summary Repository Module
This module provides data access functionality for MemorySummary nodes.
Classes:
MemorySummaryRepository: Repository for managing MemorySummary CRUD operations
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository
Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges.
Attributes:
connector: Neo4j connector instance
node_label: Node label, fixed as "MemorySummary"
"""
def __init__(self, connector: Neo4jConnector):
"""Initialize memory summary repository
Args:
connector: Neo4j connector instance
"""
super().__init__(connector, "MemorySummary")
def _map_to_dict(self, node_data: Dict) -> Dict[str, Any]:
"""Map node data to dictionary format
Args:
node_data: Node data returned from Neo4j query
Returns:
Dict[str, Any]: Memory summary data dictionary
"""
# Extract node data from query result
n = node_data.get('n', node_data)
# Handle datetime fields
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
return dict(n)
async def find_by_group_id(
self,
group_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id
Args:
group_id: Group ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
"""
params = {"group_id": group_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_by_user_id(
self,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by user_id
Args:
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.user_id = $user_id
"""
params = {"user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_by_group_and_user(
self,
group_id: str,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id
Args:
group_id: Group ID to filter by
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id
"""
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_recent_summaries(
self,
group_id: str,
days: int = 7,
limit: int = 1000
) -> List[Dict[str, Any]]:
"""Query recent memory summaries
Args:
group_id: Group ID to filter by
days: Number of recent days to query
limit: Maximum number of results to return
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
days=days,
limit=limit
)
return [self._map_to_dict(r) for r in results]
async def find_by_content_keywords(
self,
group_id: str,
keywords: List[str],
limit: int = 100
) -> List[Dict[str, Any]]:
"""Query memory summaries by content keywords
Args:
group_id: Group ID to filter by
keywords: List of keywords to search for in content
limit: Maximum number of results to return
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
# Build keyword search conditions
keyword_conditions = []
params = {"group_id": group_id, "limit": limit}
for i, keyword in enumerate(keywords):
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
params[f"keyword_{i}"] = keyword
keyword_filter = " OR ".join(keyword_conditions)
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND ({keyword_filter})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def get_summary_count_by_group(self, group_id: str) -> int:
"""Get count of memory summaries for a group
Args:
group_id: Group ID to count summaries for
Returns:
int: Number of memory summaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
results = await self.connector.execute_query(query, group_id=group_id)
return results[0]['count'] if results else 0

View File

@@ -78,7 +78,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
# 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', [])
n['access_history'] = n.get('access_history') or []
n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0)

View File

@@ -38,6 +38,33 @@ class ToolRepository:
return result[0] if result else None
@staticmethod
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
"""
根据空间ID获取tenant_id
Args:
db: 数据库会话
workspace_id: 空间ID
Returns:
tenant_id或None
"""
from app.models.workspace_model import Workspace
tenant_id = db.query(Workspace.tenant_id).filter(
Workspace.id == workspace_id
).scalar()
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
try:
tenant_id = uuid.UUID(tenant_id)
except (ValueError, TypeError):
return None
return tenant_id
@staticmethod
def find_by_tenant(
db: Session,

View File

@@ -1,6 +1,6 @@
import datetime
import uuid
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
class ToolConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
tool_id: Optional[str] = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default=None, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
)
# 工具配置
tools: Dict[str, ToolConfig] = Field(
default_factory=dict,
description="工具配置key 为工具名称web_search, code_interpreter, image_generation 等)"
tools: List[ToolConfig] = Field(
default_factory=list,
description="Agent 可用的工具列表"
)
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
# 工具配置
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
tools: Optional[List[ToolConfig]] = Field(default_factory=list, description="工具列表")
# ---------- Output Schemas ----------
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
variables: List[VariableDefinition] = []
# 工具配置
tools: Dict[str, ToolConfig] = {}
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
is_active: bool
created_at: datetime.datetime

View File

@@ -35,14 +35,14 @@ class ChatRequest(BaseModel):
class Message(BaseModel):
"""消息输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
conversation_id: uuid.UUID
role: str
content: str
meta_data: Optional[Dict[str, Any]] = None
created_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -51,7 +51,7 @@ class Message(BaseModel):
class Conversation(BaseModel):
"""会话输出"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
app_id: uuid.UUID
workspace_id: uuid.UUID
@@ -63,11 +63,11 @@ class Conversation(BaseModel):
is_active: bool
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@@ -84,3 +84,12 @@ class ChatResponse(BaseModel):
message: str
usage: Optional[Dict[str, Any]] = None
elapsed_time: Optional[float] = None
# ---------- Conversation Summary Schemas ----------
class ConversationOut(BaseModel):
theme: str
question: list[str]
summary: str
takeaways: list[str]
info_score: int

View File

@@ -0,0 +1,264 @@
"""Implicit Memory Schemas
This module defines the Pydantic schemas for the implicit memory system API.
"""
import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
class FrequencyPattern(str, Enum):
"""Frequency patterns for habits."""
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
SEASONAL = "seasonal"
OCCASIONAL = "occasional"
EVENT_TRIGGERED = "event_triggered"
# Request Schemas
class TimeRange(BaseModel):
"""Time range for analysis."""
start_date: datetime.datetime
end_date: datetime.datetime
@field_validator('end_date')
@classmethod
def end_date_must_be_after_start_date(cls, v, info):
if 'start_date' in info.data and v <= info.data['start_date']:
raise ValueError('end_date must be after start_date')
return v
@field_serializer("start_date", when_used="json")
def _serialize_start_date(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("end_date", when_used="json")
def _serialize_end_date(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class DateRange(BaseModel):
"""Date range for filtering."""
start_date: Optional[datetime.datetime] = None
end_date: Optional[datetime.datetime] = None
@field_validator('end_date')
@classmethod
def end_date_must_be_after_start_date(cls, v, info):
if v and 'start_date' in info.data and info.data['start_date'] and v <= info.data['start_date']:
raise ValueError('end_date must be after start_date')
return v
@field_serializer("start_date", when_used="json")
def _serialize_start_date(self, dt: Optional[datetime.datetime]):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("end_date", when_used="json")
def _serialize_end_date(self, dt: Optional[datetime.datetime]):
return int(dt.timestamp() * 1000) if dt else None
class AnalysisConfig(BaseModel):
"""Configuration for analysis operations."""
llm_model_id: Optional[str] = None
batch_size: int = 100
confidence_threshold: float = 0.5
include_historical_trends: bool = False
max_conversations: Optional[int] = None
# Response Schemas
class PreferenceTagResponse(BaseModel):
"""A user preference tag with detailed context."""
model_config = ConfigDict(from_attributes=True)
tag_name: str
confidence_score: float = Field(ge=0.0, le=1.0)
supporting_evidence: List[str]
context_details: str
created_at: datetime.datetime
updated_at: datetime.datetime
conversation_references: List[str]
category: Optional[str] = None
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class DimensionScoreResponse(BaseModel):
"""Score for a personality dimension."""
model_config = ConfigDict(from_attributes=True)
dimension_name: str
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str]
reasoning: str
confidence_level: int = Field(ge=0, le=100)
class DimensionPortraitResponse(BaseModel):
"""Four-dimension personality portrait."""
model_config = ConfigDict(from_attributes=True)
user_id: str
creativity: DimensionScoreResponse
aesthetic: DimensionScoreResponse
technology: DimensionScoreResponse
literature: DimensionScoreResponse
analysis_timestamp: datetime.datetime
total_summaries_analyzed: int
historical_trends: Optional[List[Dict[str, Any]]] = None
@field_serializer("analysis_timestamp", when_used="json")
def _serialize_analysis_timestamp(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class InterestCategoryResponse(BaseModel):
"""Interest category with percentage and evidence."""
model_config = ConfigDict(from_attributes=True)
category_name: str
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str]
trending_direction: Optional[str] = None
class InterestAreaDistributionResponse(BaseModel):
"""Distribution of user interests across four areas."""
model_config = ConfigDict(from_attributes=True)
user_id: str
tech: InterestCategoryResponse
lifestyle: InterestCategoryResponse
music: InterestCategoryResponse
art: InterestCategoryResponse
analysis_timestamp: datetime.datetime
total_summaries_analyzed: int
@property
def total_percentage(self) -> float:
"""Calculate total percentage across all interest areas."""
return self.tech.percentage + self.lifestyle.percentage + self.music.percentage + self.art.percentage
@field_serializer("analysis_timestamp", when_used="json")
def _serialize_analysis_timestamp(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class BehaviorHabitResponse(BaseModel):
"""A behavioral habit identified from conversations."""
model_config = ConfigDict(from_attributes=True)
habit_description: str
frequency_pattern: FrequencyPattern
time_context: str
confidence_level: int = Field(ge=0, le=100)
first_observed: datetime.datetime
last_observed: datetime.datetime
is_current: bool = True
specific_examples: List[str]
@field_serializer("first_observed", when_used="json")
def _serialize_first_observed(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("last_observed", when_used="json")
def _serialize_last_observed(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class UserProfileResponse(BaseModel):
"""Comprehensive user profile."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preference_tags: List[PreferenceTagResponse]
dimension_portrait: DimensionPortraitResponse
interest_area_distribution: InterestAreaDistributionResponse
behavior_habits: List[BehaviorHabitResponse]
profile_version: int
created_at: datetime.datetime
updated_at: datetime.datetime
total_summaries_analyzed: int
analysis_completeness_score: float = Field(ge=0.0, le=1.0)
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# Internal/Business Logic Schemas
class MemorySummary(BaseModel):
"""Memory summary from existing storage system."""
model_config = ConfigDict(from_attributes=True)
summary_id: str
content: str
timestamp: datetime.datetime
participants: List[str]
summary_type: str
@field_serializer("timestamp", when_used="json")
def _serialize_timestamp(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class UserMemorySummary(BaseModel):
"""Memory summary filtered for specific user content."""
model_config = ConfigDict(from_attributes=True)
summary_id: str
user_id: str
user_content: str
timestamp: datetime.datetime
confidence_score: float = Field(ge=0.0, le=1.0)
summary_type: str
@field_serializer("timestamp", when_used="json")
def _serialize_timestamp(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class SummaryAnalysisResult(BaseModel):
"""Result of analyzing memory summaries."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preferences: List[PreferenceTagResponse]
dimension_evidence: Dict[str, List[str]]
interest_evidence: Dict[str, List[str]]
habit_evidence: List[Dict[str, Any]]
analysis_timestamp: datetime.datetime
summaries_analyzed: List[str]
@field_serializer("analysis_timestamp", when_used="json")
def _serialize_analysis_timestamp(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
# Aliases for backward compatibility with existing code
PreferenceTag = PreferenceTagResponse
DimensionScore = DimensionScoreResponse
DimensionPortrait = DimensionPortraitResponse
InterestCategory = InterestCategoryResponse
InterestAreaDistribution = InterestAreaDistributionResponse
BehaviorHabit = BehaviorHabitResponse
UserProfile = UserProfileResponse

View File

@@ -0,0 +1,136 @@
import uuid
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
class PerceptualFilter(BaseModel):
type: PerceptualType | None = Field(
default=None,
description="Perceptual type used for filtering the query; optional"
)
class PerceptualQuerySchema(BaseModel):
filter: PerceptualFilter = Field(
default_factory=lambda: PerceptualFilter(),
description="Query filter containing perceptual type criteria"
)
page: int = Field(
default=1,
ge=1,
description="Page number for pagination, starting from 1"
)
page_size: int = Field(
default=10,
ge=1,
le=100,
description="Number of records per page, range 1-100"
)
class PerceptualMemoryItem(BaseModel):
"""感知记忆项"""
id: uuid.UUID = Field(..., description="Unique memory ID")
perceptual_type: PerceptualType = Field(..., description="Type of perception, e.g., text, audio, or video")
file_path: str = Field(..., description="File path in the storage service")
file_ext: str = Field(..., description="File extension")
file_name: str = Field(..., description="File name")
summary: Optional[str] = Field(None, description="summary")
storage_type: FileStorageType = Field(..., description="Storage type for file")
created_time: int = Field(..., description="create time")
class Config:
from_attributes = True
class PerceptualTimelineResponse(BaseModel):
"""感知记忆时间线响应"""
total: int = Field(..., description="总数量")
page: int = Field(..., description="当前页码")
page_size: int = Field(..., description="每页大小")
total_pages: int = Field(..., description="总页数")
memories: list[PerceptualMemoryItem] = Field(..., description="记忆列表")
class Config:
from_attributes = True
# --------------------------
# TODO: FileMetaData
# --------------------------
class Identity(BaseModel):
title: str
filename: str
source: str # upload | crawl | system
author: Optional[str] = None
class Semantic(BaseModel):
topic: str
domain: str
difficulty: str # beginner | intermediate | advanced
intent: str # informative | instructional | promotional
sentiment: str # positive | neutral | negative
class Content(BaseModel):
summary: str
keywords: list[str]
topic: str
domain: str
class Usage(BaseModel):
target_audience: list[str]
use_cases: list[str]
class Stats(BaseModel):
duration_sec: Optional[int] = None
char_count: int
word_count: int
class Processing(BaseModel):
transcribed: bool
ocr_applied: bool
chunked: bool
vectorized: bool
embedding_model: Optional[str] = None
class VideoModal(BaseModel):
scene: list[str]
class AudioModal(BaseModel):
speaker_count: int
class TextModal(BaseModel):
section_count: int
title: str
first_line: str
class Asset(BaseModel):
type: str
modality: str # text | audio | video
format: str # docx | mp3 | mp4
language: str
encoding: str
identity: Identity
semantic: Semantic
content: Content
usage: Usage
stats: Stats
processing: Processing
created_at: str
modalities: AudioModal | TextModal | VideoModal

View File

@@ -409,10 +409,9 @@ class ForgettingTriggerRequest(BaseModel):
"""手动触发遗忘周期请求模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
group_id: Optional[str] = Field(None, description="组ID可选,用于过滤特定组的节点")
group_id: str = Field(..., description="组ID即终端用户ID必填")
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数默认100")
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数默认30天")
config_id: Optional[int] = Field(None, description="配置ID可选用于指定遗忘引擎配置") # TODO 后续group_id更换成enduser_id自动与config_id关联 ,要删除此行
class ForgettingConfigResponse(BaseModel):
@@ -450,15 +449,36 @@ class ForgettingConfigUpdateRequest(BaseModel):
forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="遗忘周期间隔(小时)")
class ForgettingCycleHistoryPoint(BaseModel):
"""遗忘周期历史数据点模型(用于趋势图)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
date: str = Field(..., description="日期(格式: '1/1', '1/2'")
merged_count: int = Field(..., description="每日融合节点数")
average_activation: Optional[float] = Field(None, description="平均激活值")
total_nodes: int = Field(..., description="总节点数")
execution_time: int = Field(..., description="执行时间Unix时间戳")
class PendingForgettingNode(BaseModel):
"""待遗忘节点模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
node_id: str = Field(..., description="节点ID")
node_type: str = Field(..., description="节点类型statement/entity/summary")
content_summary: str = Field(..., description="内容摘要")
activation_value: float = Field(..., description="激活值")
last_access_time: int = Field(..., description="最后访问时间Unix时间戳")
class ForgettingStatsResponse(BaseModel):
"""遗忘引擎统计信息响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
consistency_check: Optional[Dict[str, Any]] = Field(None, description="数据一致性检查结果")
nodes_merged_total: int = Field(..., description="累计融合节点对数")
recent_cycles: List[Dict[str, Any]] = Field(..., description="最近的遗忘周期记录")
timestamp: str = Field(..., description="统计时间ISO格式")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)")
class ForgettingReportResponse(BaseModel):

View File

@@ -66,9 +66,9 @@ class MultiAgentConfigCreate(BaseModel):
master_agent_id: uuid.UUID = Field(..., description="主 Agent ID")
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
orchestration_mode: str = Field(
...,
pattern="^(sequential|parallel|conditional|loop)$",
description="编排模式:sequential|parallel|conditional|loop"
default="collaboration",
pattern="^(collaboration|supervisor)$",
description="协作模式:collaboration协作| supervisor监督"
)
sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表")
routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则")
@@ -84,14 +84,15 @@ class MultiAgentConfigUpdate(BaseModel):
"""更新多 Agent 配置"""
master_agent_id: Optional[uuid.UUID] = None
master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称")
default_model_config_id : uuid.UUID = Field(description="默认模型配置ID")
model_parameters: ModelParameters | None = Field(
default_factory=ModelParameters,
default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID")
model_parameters: Optional[ModelParameters] = Field(
None,
description="模型参数配置temperature、max_tokens 等)"
)
orchestration_mode: Optional[str] = Field(
None,
pattern="^(sequential|parallel|conditional|loop)$"
default="collaboration",
pattern="^(collaboration|supervisor)$",
description="协作模式collaboration协作| supervisor监督"
)
sub_agents: Optional[List[SubAgentConfig]] = None
routing_rules: Optional[List[RoutingRule]] = None

Some files were not shown because too many files have changed in this diff Show More