Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -4,40 +4,44 @@
|
||||
认证方式: JWT Token
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import (
|
||||
model_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
user_controller,
|
||||
auth_controller,
|
||||
workspace_controller,
|
||||
setup_controller,
|
||||
file_controller,
|
||||
document_controller,
|
||||
knowledge_controller,
|
||||
chunk_controller,
|
||||
knowledgeshare_controller,
|
||||
api_key_controller,
|
||||
app_controller,
|
||||
upload_controller,
|
||||
auth_controller,
|
||||
chunk_controller,
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
file_controller,
|
||||
home_page_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
memory_agent_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_storage_controller,
|
||||
memory_dashboard_controller,
|
||||
memory_forget_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
api_key_controller,
|
||||
release_share_controller,
|
||||
public_share_controller,
|
||||
memory_storage_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
workflow_controller,
|
||||
emotion_controller,
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
public_share_controller,
|
||||
release_share_controller,
|
||||
setup_controller,
|
||||
task_controller,
|
||||
test_controller,
|
||||
tool_controller,
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
@@ -76,5 +80,7 @@ manager_router.include_router(memory_short_term_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(memory_forget_controller.router)
|
||||
manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
302
api/app/controllers/implicit_memory_controller.py
Normal file
302
api/app/controllers/implicit_memory_controller.py
Normal file
@@ -0,0 +1,302 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import (
|
||||
cur_workspace_access_guard,
|
||||
get_current_user,
|
||||
)
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/implicit-memory",
|
||||
tags=["Implicit Memory"],
|
||||
)
|
||||
|
||||
|
||||
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
|
||||
"""
|
||||
Centralized error handling for implicit memory operations.
|
||||
|
||||
Args:
|
||||
e: The exception that occurred
|
||||
operation: Description of the operation that failed
|
||||
user_id: Optional user ID for logging context
|
||||
|
||||
Returns:
|
||||
Standardized error response
|
||||
"""
|
||||
error_context = f"user_id={user_id}" if user_id else "unknown user"
|
||||
|
||||
if isinstance(e, ValueError):
|
||||
if "user" in str(e).lower() and "not found" in str(e).lower():
|
||||
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e))
|
||||
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
|
||||
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e))
|
||||
else:
|
||||
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
|
||||
return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e))
|
||||
|
||||
elif isinstance(e, KeyError):
|
||||
api_logger.warning(f"Missing required data for {operation}: {error_context}")
|
||||
return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e))
|
||||
|
||||
elif isinstance(e, (ConnectionError, TimeoutError)):
|
||||
api_logger.error(f"Service unavailable for {operation}: {error_context}")
|
||||
return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e))
|
||||
|
||||
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
|
||||
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e))
|
||||
|
||||
elif "storage" in str(e).lower() or "database" in str(e).lower():
|
||||
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e))
|
||||
|
||||
else:
|
||||
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e))
|
||||
|
||||
|
||||
def validate_user_id(user_id: str) -> None:
|
||||
"""
|
||||
Validate user ID format and constraints.
|
||||
|
||||
Args:
|
||||
user_id: User ID to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If user ID is invalid
|
||||
"""
|
||||
if not user_id or not user_id.strip():
|
||||
raise ValueError("User ID cannot be empty")
|
||||
|
||||
if len(user_id.strip()) < 1:
|
||||
raise ValueError("User ID is too short")
|
||||
|
||||
|
||||
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
|
||||
"""
|
||||
Validate date range parameters.
|
||||
|
||||
Args:
|
||||
start_date: Start date
|
||||
end_date: End date
|
||||
|
||||
Raises:
|
||||
ValueError: If date range is invalid
|
||||
"""
|
||||
if (start_date and not end_date) or (end_date and not start_date):
|
||||
raise ValueError("Both start_date and end_date must be provided together")
|
||||
|
||||
if start_date and end_date and start_date >= end_date:
|
||||
raise ValueError("start_date must be before end_date")
|
||||
|
||||
if start_date and start_date > datetime.now():
|
||||
raise ValueError("start_date cannot be in the future")
|
||||
|
||||
|
||||
def validate_confidence_threshold(threshold: float) -> None:
|
||||
"""
|
||||
Validate confidence threshold parameter.
|
||||
|
||||
Args:
|
||||
threshold: Confidence threshold to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If threshold is invalid
|
||||
"""
|
||||
if not 0.0 <= threshold <= 1.0:
|
||||
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
|
||||
|
||||
|
||||
@router.get("/preferences/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_preference_tags(
|
||||
user_id: str,
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
|
||||
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
|
||||
start_date: Optional[datetime] = Query(None, description="Filter start date"),
|
||||
end_date: Optional[datetime] = Query(None, description="Filter end date"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user preference tags with filtering options.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_threshold: Minimum confidence score (0.0-1.0)
|
||||
tag_category: Optional category filter
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List of preference tags matching the filters
|
||||
"""
|
||||
api_logger.info(f"Preference tags requested for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
validate_confidence_threshold(confidence_threshold)
|
||||
validate_date_range(start_date, end_date)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
# Build date range
|
||||
date_range = None
|
||||
if start_date and end_date:
|
||||
from app.schemas.implicit_memory_schema import DateRange
|
||||
date_range = DateRange(start_date=start_date, end_date=end_date)
|
||||
|
||||
# Get preference tags
|
||||
tags = await service.get_preference_tags(
|
||||
user_id=user_id,
|
||||
confidence_threshold=confidence_threshold,
|
||||
tag_category=tag_category,
|
||||
date_range=date_range
|
||||
)
|
||||
|
||||
api_logger.info(f"Retrieved {len(tags)} preference tags for user: {user_id}")
|
||||
return success(data=[tag.dict() for tag in tags], msg="偏好标签获取成功")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
|
||||
|
||||
|
||||
@router.get("/portrait/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_dimension_portrait(
|
||||
user_id: str,
|
||||
include_history: bool = Query(False, description="Include historical trends"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's four-dimension personality portrait.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_history: Whether to include historical trend data
|
||||
|
||||
Returns:
|
||||
Four-dimension personality portrait with scores and evidence
|
||||
"""
|
||||
api_logger.info(f"Dimension portrait requested for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
portrait = await service.get_dimension_portrait(
|
||||
user_id=user_id,
|
||||
include_history=include_history
|
||||
)
|
||||
|
||||
api_logger.info(f"Dimension portrait retrieved for user: {user_id}")
|
||||
return success(data=portrait.dict(), msg="四维画像获取成功")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "四维画像获取", user_id)
|
||||
|
||||
|
||||
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_interest_area_distribution(
|
||||
user_id: str,
|
||||
include_trends: bool = Query(False, description="Include trend analysis"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's interest area distribution across four areas.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_trends: Whether to include trend analysis data
|
||||
|
||||
Returns:
|
||||
Interest area distribution with percentages and evidence
|
||||
"""
|
||||
api_logger.info(f"Interest area distribution requested for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
distribution = await service.get_interest_area_distribution(
|
||||
user_id=user_id,
|
||||
include_trends=include_trends
|
||||
)
|
||||
|
||||
api_logger.info(f"Interest area distribution retrieved for user: {user_id}")
|
||||
return success(data=distribution.dict(), msg="兴趣领域分布获取成功")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
|
||||
|
||||
|
||||
@router.get("/habits/{user_id}", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
async def get_behavior_habits(
|
||||
user_id: str,
|
||||
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
|
||||
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
|
||||
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> ApiResponse:
|
||||
"""
|
||||
Get user's behavioral habits with filtering options.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_level: Filter by confidence level (high, medium, low)
|
||||
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
|
||||
time_period: Filter by time period (current, past)
|
||||
|
||||
Returns:
|
||||
List of behavioral habits matching the filters
|
||||
"""
|
||||
api_logger.info(f"Behavior habits requested for user: {user_id}")
|
||||
|
||||
try:
|
||||
# Validate inputs
|
||||
validate_user_id(user_id)
|
||||
|
||||
# Create service with user-specific config
|
||||
service = ImplicitMemoryService(db=db, end_user_id=user_id)
|
||||
|
||||
habits = await service.get_behavior_habits(
|
||||
user_id=user_id,
|
||||
confidence_level=confidence_level,
|
||||
frequency_pattern=frequency_pattern,
|
||||
time_period=time_period
|
||||
)
|
||||
|
||||
api_logger.info(f"Retrieved {len(habits)} behavior habits for user: {user_id}")
|
||||
return success(data=[habit.dict() for habit in habits], msg="行为习惯获取成功")
|
||||
|
||||
except Exception as e:
|
||||
return handle_implicit_memory_error(e, "行为习惯获取", user_id)
|
||||
|
||||
|
||||
@@ -82,6 +82,13 @@ class BizCode(IntEnum):
|
||||
MEMORY_WRITE_FAILED = 9501
|
||||
MEMORY_READ_FAILED = 9502
|
||||
MEMORY_CONFIG_NOT_FOUND = 9503
|
||||
|
||||
# Implicit Memory API(96xx)
|
||||
INVALID_USER_ID = 9601
|
||||
INSUFFICIENT_DATA = 9602
|
||||
INVALID_FILTER_PARAMS = 9603
|
||||
ANALYSIS_FAILED = 9604
|
||||
PROFILE_STORAGE_ERROR = 9605
|
||||
|
||||
# 系统(100xx)
|
||||
INTERNAL_ERROR = 10001
|
||||
@@ -159,6 +166,13 @@ HTTP_MAPPING = {
|
||||
BizCode.MEMORY_READ_FAILED: 500,
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
|
||||
|
||||
# Implicit Memory API 错误码映射
|
||||
BizCode.INVALID_USER_ID: 400,
|
||||
BizCode.INSUFFICIENT_DATA: 400,
|
||||
BizCode.INVALID_FILTER_PARAMS: 400,
|
||||
BizCode.ANALYSIS_FAILED: 500,
|
||||
BizCode.PROFILE_STORAGE_ERROR: 500,
|
||||
|
||||
BizCode.INTERNAL_ERROR: 500,
|
||||
BizCode.DB_ERROR: 500,
|
||||
BizCode.SERVICE_UNAVAILABLE: 503,
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Implicit Memory Module
|
||||
|
||||
This module provides behavior analysis capabilities that build comprehensive user profiles
|
||||
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
|
||||
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
|
||||
"""
|
||||
@@ -0,0 +1 @@
|
||||
"""Analyzers package for implicit memory analysis components."""
|
||||
@@ -0,0 +1,264 @@
|
||||
"""Dimension Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based personality dimension analysis from user memory summaries.
|
||||
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
|
||||
providing percentage scores with evidence and reasoning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
ConfidenceLevel,
|
||||
DimensionPortrait,
|
||||
DimensionScore,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DimensionData(BaseModel):
|
||||
"""Individual dimension analysis data."""
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str] = Field(default_factory=list)
|
||||
reasoning: str = ""
|
||||
confidence_level: str = "medium"
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
|
||||
"""Response model for dimension analysis."""
|
||||
dimensions: Dict[str, DimensionData] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class DimensionAnalyzer:
|
||||
"""Analyzes user memory summaries to extract personality dimensions."""
|
||||
|
||||
# Define the four dimensions we analyze
|
||||
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the dimension analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_dimensions(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_portrait: Optional[DimensionPortrait] = None
|
||||
) -> DimensionPortrait:
|
||||
"""Analyze user summaries to extract personality dimensions.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_portrait: Optional existing portrait for incremental updates
|
||||
|
||||
Returns:
|
||||
Dimension portrait with four personality dimensions
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return self._create_empty_portrait(user_id)
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_dimensions(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Create dimension scores
|
||||
dimension_scores = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
for dimension_name in self.DIMENSIONS:
|
||||
# Handle response as dictionary
|
||||
dimensions_data = response.get("dimensions", {})
|
||||
dimension_data = dimensions_data.get(dimension_name)
|
||||
|
||||
if dimension_data:
|
||||
# Validate and create dimension score
|
||||
score = self._create_dimension_score(
|
||||
dimension_name=dimension_name,
|
||||
dimension_data=dimension_data
|
||||
)
|
||||
dimension_scores[dimension_name] = score
|
||||
else:
|
||||
# Create default score if missing
|
||||
logger.warning(f"Missing dimension data for {dimension_name}, using default")
|
||||
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
literature=dimension_scores["literature"],
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=len(user_summaries),
|
||||
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
|
||||
)
|
||||
|
||||
logger.info(f"Created dimension portrait for user {user_id}")
|
||||
return portrait
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Dimension analysis failed: {e}") from e
|
||||
|
||||
def _create_dimension_score(
|
||||
self,
|
||||
dimension_name: str,
|
||||
dimension_data: dict
|
||||
) -> DimensionScore:
|
||||
"""Create a dimension score from analysis data.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
dimension_data: Analysis data dictionary for the dimension
|
||||
|
||||
Returns:
|
||||
DimensionScore object
|
||||
"""
|
||||
# Validate percentage - handle dict access
|
||||
percentage = dimension_data.get("percentage", 0.0)
|
||||
percentage = max(0.0, min(100.0, float(percentage)))
|
||||
|
||||
# Validate confidence level
|
||||
confidence_level_str = dimension_data.get("confidence_level", "low")
|
||||
confidence_level = self._validate_confidence_level(confidence_level_str)
|
||||
|
||||
# Ensure evidence is not empty
|
||||
evidence = dimension_data.get("evidence", [])
|
||||
if not evidence:
|
||||
evidence = ["No specific evidence found"]
|
||||
|
||||
# Ensure reasoning is not empty
|
||||
reasoning = dimension_data.get("reasoning", "")
|
||||
if not reasoning:
|
||||
reasoning = f"Analysis for {dimension_name} dimension"
|
||||
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=percentage,
|
||||
evidence=evidence,
|
||||
reasoning=reasoning,
|
||||
confidence_level=confidence_level
|
||||
)
|
||||
|
||||
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
|
||||
"""Create a default dimension score when analysis fails.
|
||||
|
||||
Args:
|
||||
dimension_name: Name of the dimension
|
||||
|
||||
Returns:
|
||||
Default DimensionScore object
|
||||
"""
|
||||
return DimensionScore(
|
||||
dimension_name=dimension_name,
|
||||
percentage=0.0,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
reasoning=f"No clear evidence found for {dimension_name} dimension",
|
||||
confidence_level=ConfidenceLevel.LOW
|
||||
)
|
||||
|
||||
def _validate_confidence_level(self, confidence_str: str) -> ConfidenceLevel:
|
||||
"""Validate and convert confidence level string.
|
||||
|
||||
Args:
|
||||
confidence_str: Confidence level as string
|
||||
|
||||
Returns:
|
||||
ConfidenceLevel enum value
|
||||
"""
|
||||
if not confidence_str:
|
||||
return ConfidenceLevel.MEDIUM
|
||||
|
||||
confidence_str = str(confidence_str).lower().strip()
|
||||
|
||||
if confidence_str in ["high", "높음"]:
|
||||
return ConfidenceLevel.HIGH
|
||||
elif confidence_str in ["medium", "중간"]:
|
||||
return ConfidenceLevel.MEDIUM
|
||||
elif confidence_str in ["low", "낮음"]:
|
||||
return ConfidenceLevel.LOW
|
||||
else:
|
||||
logger.warning(f"Unknown confidence level: {confidence_str}, defaulting to medium")
|
||||
return ConfidenceLevel.MEDIUM
|
||||
|
||||
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
|
||||
"""Create an empty dimension portrait when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
|
||||
Returns:
|
||||
Empty DimensionPortrait
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
|
||||
return DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=self._create_default_dimension_score("creativity"),
|
||||
aesthetic=self._create_default_dimension_score("aesthetic"),
|
||||
technology=self._create_default_dimension_score("technology"),
|
||||
literature=self._create_default_dimension_score("literature"),
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=0,
|
||||
historical_trends=None
|
||||
)
|
||||
|
||||
def _calculate_historical_trends(
|
||||
self,
|
||||
existing_portrait: DimensionPortrait
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Calculate historical trends from existing portrait.
|
||||
|
||||
Args:
|
||||
existing_portrait: Previous dimension portrait
|
||||
|
||||
Returns:
|
||||
List of historical trend data
|
||||
"""
|
||||
if not existing_portrait:
|
||||
return []
|
||||
|
||||
# Create trend entry from existing portrait
|
||||
trend_entry = {
|
||||
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
|
||||
"creativity": existing_portrait.creativity.percentage,
|
||||
"aesthetic": existing_portrait.aesthetic.percentage,
|
||||
"technology": existing_portrait.technology.percentage,
|
||||
"literature": existing_portrait.literature.percentage,
|
||||
"total_summaries": existing_portrait.total_summaries_analyzed
|
||||
}
|
||||
|
||||
# Combine with existing trends
|
||||
existing_trends = existing_portrait.historical_trends or []
|
||||
|
||||
# Keep only recent trends (last 10 entries)
|
||||
all_trends = existing_trends + [trend_entry]
|
||||
return all_trends[-10:]
|
||||
@@ -0,0 +1,452 @@
|
||||
"""Habit Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based behavioral habit analysis from user memory summaries.
|
||||
It identifies recurring behavioral patterns, temporal patterns, and consolidates
|
||||
similar habits with confidence scoring.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
ConfidenceLevel,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitData(BaseModel):
|
||||
"""Individual habit analysis data."""
|
||||
habit_description: str
|
||||
frequency_pattern: str
|
||||
time_context: str
|
||||
confidence_level: str
|
||||
supporting_summaries: List[str] = Field(default_factory=list)
|
||||
specific_examples: List[str] = Field(default_factory=list)
|
||||
is_current: bool = True
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
|
||||
"""Response model for habit analysis."""
|
||||
habits: List[HabitData] = Field(default_factory=list)
|
||||
|
||||
|
||||
class HabitAnalyzer:
|
||||
"""Analyzes user memory summaries to extract behavioral habits."""
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the habit analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_habits(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_habits: Optional[List[BehaviorHabit]] = None
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Analyze user summaries to extract behavioral habits.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_habits: Optional existing habits for consolidation
|
||||
|
||||
Returns:
|
||||
List of extracted behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return existing_habits or []
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_habits(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Convert to BehaviorHabit objects
|
||||
behavior_habits = []
|
||||
current_time = datetime.now()
|
||||
|
||||
for habit_data in response.get("habits", []):
|
||||
try:
|
||||
# Handle habit_data as dictionary
|
||||
supporting_summaries = habit_data.get("supporting_summaries", [])
|
||||
specific_examples = habit_data.get("specific_examples", [])
|
||||
|
||||
# Determine observation dates from summaries
|
||||
first_observed, last_observed = self._determine_observation_dates(
|
||||
user_summaries, supporting_summaries
|
||||
)
|
||||
|
||||
behavior_habit = BehaviorHabit(
|
||||
habit_description=habit_data.get("habit_description", ""),
|
||||
frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")),
|
||||
time_context=habit_data.get("time_context", ""),
|
||||
confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", "medium")),
|
||||
supporting_summaries=supporting_summaries,
|
||||
specific_examples=specific_examples,
|
||||
first_observed=first_observed,
|
||||
last_observed=last_observed,
|
||||
is_current=habit_data.get("is_current", True)
|
||||
)
|
||||
|
||||
# Validate habit
|
||||
if self._is_valid_habit(behavior_habit):
|
||||
behavior_habits.append(behavior_habit)
|
||||
else:
|
||||
logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating behavior habit: {e}")
|
||||
continue
|
||||
|
||||
# Consolidate with existing habits if provided
|
||||
if existing_habits:
|
||||
behavior_habits = self._consolidate_habits(
|
||||
new_habits=behavior_habits,
|
||||
existing_habits=existing_habits
|
||||
)
|
||||
|
||||
# Sort habits by confidence and recency
|
||||
behavior_habits = self._sort_habits_by_priority(behavior_habits)
|
||||
|
||||
logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
|
||||
return behavior_habits
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Habit analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit analysis failed: {e}") from e
|
||||
|
||||
def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
|
||||
"""Validate and convert frequency pattern string.
|
||||
|
||||
Args:
|
||||
frequency_str: Frequency pattern as string
|
||||
|
||||
Returns:
|
||||
FrequencyPattern enum value
|
||||
"""
|
||||
frequency_str = frequency_str.lower().strip()
|
||||
|
||||
frequency_mapping = {
|
||||
"daily": FrequencyPattern.DAILY,
|
||||
"weekly": FrequencyPattern.WEEKLY,
|
||||
"monthly": FrequencyPattern.MONTHLY,
|
||||
"seasonal": FrequencyPattern.SEASONAL,
|
||||
"occasional": FrequencyPattern.OCCASIONAL,
|
||||
"event_triggered": FrequencyPattern.EVENT_TRIGGERED,
|
||||
"event-triggered": FrequencyPattern.EVENT_TRIGGERED,
|
||||
}
|
||||
|
||||
return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL)
|
||||
|
||||
def _validate_confidence_level(self, confidence_str: str) -> ConfidenceLevel:
|
||||
"""Validate and convert confidence level string.
|
||||
|
||||
Args:
|
||||
confidence_str: Confidence level as string
|
||||
|
||||
Returns:
|
||||
ConfidenceLevel enum value
|
||||
"""
|
||||
confidence_str = confidence_str.lower().strip()
|
||||
|
||||
if confidence_str in ["high", "높음"]:
|
||||
return ConfidenceLevel.HIGH
|
||||
elif confidence_str in ["medium", "중간"]:
|
||||
return ConfidenceLevel.MEDIUM
|
||||
elif confidence_str in ["low", "낮음"]:
|
||||
return ConfidenceLevel.LOW
|
||||
else:
|
||||
logger.warning(f"Unknown confidence level: {confidence_str}, defaulting to medium")
|
||||
return ConfidenceLevel.MEDIUM
|
||||
|
||||
def _determine_observation_dates(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
supporting_summary_ids: List[str]
|
||||
) -> tuple[datetime, datetime]:
|
||||
"""Determine first and last observation dates for a habit.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user summaries
|
||||
supporting_summary_ids: IDs of summaries supporting the habit
|
||||
|
||||
Returns:
|
||||
Tuple of (first_observed, last_observed) dates
|
||||
"""
|
||||
from datetime import timezone
|
||||
|
||||
# Find summaries that support this habit
|
||||
supporting_summaries = [
|
||||
summary for summary in user_summaries
|
||||
if summary.summary_id in supporting_summary_ids
|
||||
]
|
||||
|
||||
if not supporting_summaries:
|
||||
# Use all summaries if no specific supporting summaries found
|
||||
supporting_summaries = user_summaries
|
||||
|
||||
if not supporting_summaries:
|
||||
current_time = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
return current_time, current_time
|
||||
|
||||
# Get date range from supporting summaries - normalize to naive datetimes
|
||||
timestamps = []
|
||||
for summary in supporting_summaries:
|
||||
ts = summary.timestamp
|
||||
# Convert to naive datetime if it's timezone-aware
|
||||
if ts.tzinfo is not None:
|
||||
ts = ts.replace(tzinfo=None)
|
||||
timestamps.append(ts)
|
||||
|
||||
first_observed = min(timestamps)
|
||||
last_observed = max(timestamps)
|
||||
|
||||
return first_observed, last_observed
|
||||
|
||||
def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
|
||||
"""Validate a behavioral habit.
|
||||
|
||||
Args:
|
||||
habit: Behavioral habit to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check required fields
|
||||
if not habit.habit_description or not habit.habit_description.strip():
|
||||
return False
|
||||
|
||||
# Check time context
|
||||
if not habit.time_context or not habit.time_context.strip():
|
||||
return False
|
||||
|
||||
# Check supporting summaries
|
||||
if not habit.supporting_summaries or len(habit.supporting_summaries) == 0:
|
||||
return False
|
||||
|
||||
# Check specific examples
|
||||
if not habit.specific_examples or len(habit.specific_examples) == 0:
|
||||
return False
|
||||
|
||||
# Check observation dates
|
||||
if habit.first_observed > habit.last_observed:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating habit: {e}")
|
||||
return False
|
||||
|
||||
def _consolidate_habits(
|
||||
self,
|
||||
new_habits: List[BehaviorHabit],
|
||||
existing_habits: List[BehaviorHabit],
|
||||
similarity_threshold: float = 0.7
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Consolidate new habits with existing ones.
|
||||
|
||||
Args:
|
||||
new_habits: Newly extracted habits
|
||||
existing_habits: Existing habits
|
||||
similarity_threshold: Threshold for considering habits similar
|
||||
|
||||
Returns:
|
||||
Consolidated list of habits
|
||||
"""
|
||||
consolidated = existing_habits.copy()
|
||||
current_time = datetime.now()
|
||||
|
||||
for new_habit in new_habits:
|
||||
# Find similar existing habit
|
||||
similar_habit = self._find_similar_habit(
|
||||
new_habit, existing_habits, similarity_threshold
|
||||
)
|
||||
|
||||
if similar_habit:
|
||||
# Update existing habit
|
||||
updated_habit = self._merge_habits(similar_habit, new_habit, current_time)
|
||||
# Replace in consolidated list
|
||||
for i, habit in enumerate(consolidated):
|
||||
if habit.habit_description == similar_habit.habit_description:
|
||||
consolidated[i] = updated_habit
|
||||
break
|
||||
else:
|
||||
# Add as new habit
|
||||
consolidated.append(new_habit)
|
||||
|
||||
return consolidated
|
||||
|
||||
def _find_similar_habit(
|
||||
self,
|
||||
target_habit: BehaviorHabit,
|
||||
existing_habits: List[BehaviorHabit],
|
||||
threshold: float
|
||||
) -> Optional[BehaviorHabit]:
|
||||
"""Find similar habit in existing list.
|
||||
|
||||
Args:
|
||||
target_habit: Habit to find similarity for
|
||||
existing_habits: List of existing habits
|
||||
threshold: Similarity threshold
|
||||
|
||||
Returns:
|
||||
Similar habit if found, None otherwise
|
||||
"""
|
||||
target_desc = target_habit.habit_description.lower().strip()
|
||||
|
||||
for existing_habit in existing_habits:
|
||||
existing_desc = existing_habit.habit_description.lower().strip()
|
||||
|
||||
# Check description similarity
|
||||
desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)
|
||||
|
||||
# Check frequency pattern match
|
||||
frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern)
|
||||
|
||||
# Check time context similarity
|
||||
time_similarity = self._calculate_text_similarity(
|
||||
target_habit.time_context.lower(),
|
||||
existing_habit.time_context.lower()
|
||||
)
|
||||
|
||||
# Combined similarity score
|
||||
combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4)
|
||||
if frequency_match:
|
||||
combined_similarity += 0.1 # Bonus for frequency match
|
||||
|
||||
if combined_similarity >= threshold:
|
||||
return existing_habit
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Calculate simple text similarity based on common words.
|
||||
|
||||
Args:
|
||||
text1: First text
|
||||
text2: Second text
|
||||
|
||||
Returns:
|
||||
Similarity score between 0.0 and 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Simple word-based similarity
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
if not words1 or not words2:
|
||||
return 0.0
|
||||
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
return len(intersection) / len(union) if union else 0.0
|
||||
|
||||
def _merge_habits(
|
||||
self,
|
||||
existing_habit: BehaviorHabit,
|
||||
new_habit: BehaviorHabit,
|
||||
current_time: datetime
|
||||
) -> BehaviorHabit:
|
||||
"""Merge two similar habits.
|
||||
|
||||
Args:
|
||||
existing_habit: Existing habit
|
||||
new_habit: New habit to merge
|
||||
current_time: Current timestamp
|
||||
|
||||
Returns:
|
||||
Merged behavioral habit
|
||||
"""
|
||||
# Combine supporting summaries
|
||||
combined_summaries = list(set(
|
||||
existing_habit.supporting_summaries + new_habit.supporting_summaries
|
||||
))
|
||||
|
||||
# Combine specific examples
|
||||
combined_examples = list(set(
|
||||
existing_habit.specific_examples + new_habit.specific_examples
|
||||
))
|
||||
|
||||
# Update confidence level (take higher confidence)
|
||||
confidence_levels = [existing_habit.confidence_level, new_habit.confidence_level]
|
||||
new_confidence = max(confidence_levels, key=lambda x: ["low", "medium", "high"].index(x.value))
|
||||
|
||||
# Update observation dates
|
||||
first_observed = min(existing_habit.first_observed, new_habit.first_observed)
|
||||
last_observed = max(existing_habit.last_observed, new_habit.last_observed)
|
||||
|
||||
# Determine if habit is current (observed within last 30 days)
|
||||
is_current = (current_time - last_observed).days <= 30
|
||||
|
||||
# Combine time context
|
||||
combined_time_context = existing_habit.time_context
|
||||
if new_habit.time_context and new_habit.time_context not in combined_time_context:
|
||||
combined_time_context += f"; {new_habit.time_context}"
|
||||
|
||||
return BehaviorHabit(
|
||||
habit_description=existing_habit.habit_description, # Keep original description
|
||||
frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency
|
||||
time_context=combined_time_context,
|
||||
confidence_level=new_confidence,
|
||||
supporting_summaries=combined_summaries,
|
||||
specific_examples=combined_examples,
|
||||
first_observed=first_observed,
|
||||
last_observed=last_observed,
|
||||
is_current=is_current
|
||||
)
|
||||
|
||||
def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
|
||||
"""Sort habits by confidence level and recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to sort
|
||||
|
||||
Returns:
|
||||
Sorted list of habits
|
||||
"""
|
||||
def priority_score(habit: BehaviorHabit) -> tuple:
|
||||
# Confidence level score (high=3, medium=2, low=1)
|
||||
confidence_score = {"high": 3, "medium": 2, "low": 1}.get(habit.confidence_level.value, 1)
|
||||
|
||||
# Recency score (more recent = higher score)
|
||||
days_since_last = (datetime.now() - habit.last_observed).days
|
||||
recency_score = max(0, 365 - days_since_last) # Max 365 days
|
||||
|
||||
# Current habit bonus
|
||||
current_bonus = 100 if habit.is_current else 0
|
||||
|
||||
return (confidence_score, recency_score + current_bonus, habit.last_observed)
|
||||
|
||||
return sorted(habits, key=priority_score, reverse=True)
|
||||
@@ -0,0 +1,277 @@
|
||||
"""Interest Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based interest area analysis from user memory summaries.
|
||||
It categorizes user interests into four areas: tech, lifestyle, music, and art,
|
||||
providing percentage distribution that totals 100%.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
InterestAreaDistribution,
|
||||
InterestCategory,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InterestData(BaseModel):
|
||||
"""Individual interest category analysis data."""
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str] = Field(default_factory=list)
|
||||
trending_direction: Optional[str] = None
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
|
||||
"""Response model for interest analysis."""
|
||||
interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalyzer:
|
||||
"""Analyzes user memory summaries to extract interest area distribution."""
|
||||
|
||||
# Define the four interest categories we analyze
|
||||
INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the interest analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_interests(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_distribution: Optional[InterestAreaDistribution] = None
|
||||
) -> InterestAreaDistribution:
|
||||
"""Analyze user summaries to extract interest area distribution.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_distribution: Optional existing distribution for trend tracking
|
||||
|
||||
Returns:
|
||||
Interest area distribution across four categories
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return self._create_empty_distribution(user_id)
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_interests(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Create interest categories
|
||||
interest_categories = {}
|
||||
current_time = datetime.now()
|
||||
|
||||
# Extract interest_distribution from response dict
|
||||
interest_distribution = response.get("interest_distribution", {})
|
||||
|
||||
# Extract and validate interest data
|
||||
raw_interests = {}
|
||||
for category_name in self.INTEREST_CATEGORIES:
|
||||
interest_data_dict = interest_distribution.get(category_name)
|
||||
if interest_data_dict:
|
||||
raw_interests[category_name] = InterestData(
|
||||
percentage=interest_data_dict.get("percentage", 0.0),
|
||||
evidence=interest_data_dict.get("evidence", []),
|
||||
trending_direction=interest_data_dict.get("trending_direction")
|
||||
)
|
||||
else:
|
||||
# Create default if missing
|
||||
logger.warning(f"Missing interest data for {category_name}, using default")
|
||||
raw_interests[category_name] = InterestData(
|
||||
percentage=0.0,
|
||||
evidence=["No specific evidence found"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
# Normalize percentages to ensure they sum to 100%
|
||||
normalized_interests = self._normalize_percentages(raw_interests)
|
||||
|
||||
# Create interest category objects
|
||||
for category_name in self.INTEREST_CATEGORIES:
|
||||
interest_data = normalized_interests[category_name]
|
||||
|
||||
# Calculate trending direction if we have existing data
|
||||
trending_direction = self._calculate_trending_direction(
|
||||
category_name=category_name,
|
||||
current_percentage=interest_data.percentage,
|
||||
existing_distribution=existing_distribution
|
||||
) if existing_distribution else interest_data.trending_direction
|
||||
|
||||
interest_categories[category_name] = InterestCategory(
|
||||
category_name=category_name,
|
||||
percentage=interest_data.percentage,
|
||||
evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
|
||||
trending_direction=trending_direction
|
||||
)
|
||||
|
||||
# Create interest area distribution
|
||||
distribution = InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=interest_categories["tech"],
|
||||
lifestyle=interest_categories["lifestyle"],
|
||||
music=interest_categories["music"],
|
||||
art=interest_categories["art"],
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=len(user_summaries)
|
||||
)
|
||||
|
||||
# Validate that percentages sum to 100%
|
||||
total_percentage = distribution.total_percentage
|
||||
if not (99.9 <= total_percentage <= 100.1):
|
||||
logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")
|
||||
|
||||
logger.info(f"Created interest distribution for user {user_id}")
|
||||
return distribution
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Interest analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Interest analysis failed: {e}") from e
|
||||
|
||||
def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
|
||||
"""Normalize percentages to ensure they sum to 100%.
|
||||
|
||||
Args:
|
||||
raw_interests: Raw interest data with potentially unnormalized percentages
|
||||
|
||||
Returns:
|
||||
Normalized interest data
|
||||
"""
|
||||
# Calculate current total
|
||||
total = sum(interest.percentage for interest in raw_interests.values())
|
||||
|
||||
if total == 0:
|
||||
# If all percentages are 0, distribute equally
|
||||
equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)
|
||||
normalized = {}
|
||||
for category_name, interest_data in raw_interests.items():
|
||||
normalized[category_name] = InterestData(
|
||||
percentage=equal_percentage,
|
||||
evidence=interest_data.evidence,
|
||||
trending_direction=interest_data.trending_direction
|
||||
)
|
||||
return normalized
|
||||
|
||||
# Normalize to sum to 100%
|
||||
normalization_factor = 100.0 / total
|
||||
normalized = {}
|
||||
|
||||
for category_name, interest_data in raw_interests.items():
|
||||
normalized_percentage = interest_data.percentage * normalization_factor
|
||||
|
||||
normalized[category_name] = InterestData(
|
||||
percentage=round(normalized_percentage, 1),
|
||||
evidence=interest_data.evidence,
|
||||
trending_direction=interest_data.trending_direction
|
||||
)
|
||||
|
||||
# Handle rounding errors by adjusting the largest category
|
||||
current_total = sum(interest.percentage for interest in normalized.values())
|
||||
if abs(current_total - 100.0) > 0.1:
|
||||
# Find category with largest percentage and adjust
|
||||
largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage)
|
||||
adjustment = 100.0 - current_total
|
||||
|
||||
adjusted_percentage = normalized[largest_category].percentage + adjustment
|
||||
normalized[largest_category] = InterestData(
|
||||
percentage=round(max(0.0, adjusted_percentage), 1),
|
||||
evidence=normalized[largest_category].evidence,
|
||||
trending_direction=normalized[largest_category].trending_direction
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
def _calculate_trending_direction(
|
||||
self,
|
||||
category_name: str,
|
||||
current_percentage: float,
|
||||
existing_distribution: InterestAreaDistribution,
|
||||
threshold: float = 5.0
|
||||
) -> Optional[str]:
|
||||
"""Calculate trending direction for an interest category.
|
||||
|
||||
Args:
|
||||
category_name: Name of the interest category
|
||||
current_percentage: Current percentage for the category
|
||||
existing_distribution: Previous distribution for comparison
|
||||
threshold: Minimum percentage change to consider a trend
|
||||
|
||||
Returns:
|
||||
Trending direction: "increasing", "decreasing", "stable", or None
|
||||
"""
|
||||
try:
|
||||
# Get previous percentage
|
||||
previous_category = getattr(existing_distribution, category_name, None)
|
||||
if not previous_category:
|
||||
return None
|
||||
|
||||
previous_percentage = previous_category.percentage
|
||||
change = current_percentage - previous_percentage
|
||||
|
||||
if abs(change) < threshold:
|
||||
return "stable"
|
||||
elif change > 0:
|
||||
return "increasing"
|
||||
else:
|
||||
return "decreasing"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating trending direction for {category_name}: {e}")
|
||||
return None
|
||||
|
||||
def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
|
||||
"""Create an empty interest distribution when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
|
||||
Returns:
|
||||
Empty InterestAreaDistribution with equal percentages
|
||||
"""
|
||||
current_time = datetime.now()
|
||||
equal_percentage = 25.0 # 100% / 4 categories
|
||||
|
||||
default_category = lambda name: InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
return InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=default_category("tech"),
|
||||
lifestyle=default_category("lifestyle"),
|
||||
music=default_category("music"),
|
||||
art=default_category("art"),
|
||||
analysis_timestamp=current_time,
|
||||
total_summaries_analyzed=0
|
||||
)
|
||||
@@ -0,0 +1,302 @@
|
||||
"""Preference Analyzer for Implicit Memory System
|
||||
|
||||
This module implements LLM-based preference extraction from user memory summaries.
|
||||
It identifies implicit preferences, consolidates similar preferences, and calculates
|
||||
confidence scores based on evidence strength.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
PreferenceTag,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
|
||||
"""Response model for preference analysis."""
|
||||
preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PreferenceAnalyzer:
|
||||
"""Analyzes user memory summaries to extract implicit preferences."""
|
||||
|
||||
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
|
||||
"""Initialize the preference analyzer.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
|
||||
|
||||
async def analyze_preferences(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_preferences: Optional[List[PreferenceTag]] = None
|
||||
) -> List[PreferenceTag]:
|
||||
"""Analyze user summaries to extract preferences.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_preferences: Optional existing preferences for consolidation
|
||||
|
||||
Returns:
|
||||
List of extracted preference tags
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return []
|
||||
|
||||
try:
|
||||
logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
# Use the LLM client wrapper for analysis
|
||||
response = await self._llm_client.analyze_preferences(
|
||||
user_summaries=user_summaries,
|
||||
user_id=user_id,
|
||||
model_id=self.llm_model_id
|
||||
)
|
||||
|
||||
# Convert to PreferenceTag objects
|
||||
preference_tags = []
|
||||
current_time = datetime.now()
|
||||
|
||||
for pref_data in response.get("preferences", []):
|
||||
try:
|
||||
# Extract conversation references from summaries
|
||||
conversation_refs = [s.summary_id for s in user_summaries]
|
||||
|
||||
preference_tag = PreferenceTag(
|
||||
tag_name=pref_data.get("tag_name", ""),
|
||||
confidence_score=float(pref_data.get("confidence_score", 0.0)),
|
||||
supporting_evidence=pref_data.get("supporting_evidence", []),
|
||||
context_details=pref_data.get("context_details", ""),
|
||||
category=pref_data.get("category"),
|
||||
conversation_references=conversation_refs,
|
||||
created_at=current_time,
|
||||
updated_at=current_time
|
||||
)
|
||||
|
||||
# Validate preference tag
|
||||
if self._is_valid_preference(preference_tag):
|
||||
preference_tags.append(preference_tag)
|
||||
else:
|
||||
logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating preference tag: {e}")
|
||||
continue
|
||||
|
||||
# Consolidate with existing preferences if provided
|
||||
if existing_preferences:
|
||||
preference_tags = self._consolidate_preferences(
|
||||
new_preferences=preference_tags,
|
||||
existing_preferences=existing_preferences
|
||||
)
|
||||
|
||||
logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
|
||||
return preference_tags
|
||||
|
||||
except LLMClientException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Preference analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Preference analysis failed: {e}") from e
|
||||
|
||||
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
|
||||
"""Validate a preference tag.
|
||||
|
||||
Args:
|
||||
preference: Preference tag to validate
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Check required fields
|
||||
if not preference.tag_name or not preference.tag_name.strip():
|
||||
return False
|
||||
|
||||
# Check confidence score range
|
||||
if not (0.0 <= preference.confidence_score <= 1.0):
|
||||
return False
|
||||
|
||||
# Check supporting evidence
|
||||
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
|
||||
return False
|
||||
|
||||
# Check context details
|
||||
if not preference.context_details or not preference.context_details.strip():
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating preference: {e}")
|
||||
return False
|
||||
|
||||
def _consolidate_preferences(
|
||||
self,
|
||||
new_preferences: List[PreferenceTag],
|
||||
existing_preferences: List[PreferenceTag],
|
||||
similarity_threshold: float = 0.8
|
||||
) -> List[PreferenceTag]:
|
||||
"""Consolidate new preferences with existing ones.
|
||||
|
||||
Args:
|
||||
new_preferences: Newly extracted preferences
|
||||
existing_preferences: Existing preferences
|
||||
similarity_threshold: Threshold for considering preferences similar
|
||||
|
||||
Returns:
|
||||
Consolidated list of preferences
|
||||
"""
|
||||
consolidated = existing_preferences.copy()
|
||||
current_time = datetime.now()
|
||||
|
||||
for new_pref in new_preferences:
|
||||
# Find similar existing preference
|
||||
similar_pref = self._find_similar_preference(
|
||||
new_pref, existing_preferences, similarity_threshold
|
||||
)
|
||||
|
||||
if similar_pref:
|
||||
# Update existing preference
|
||||
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
|
||||
# Replace in consolidated list
|
||||
for i, pref in enumerate(consolidated):
|
||||
if pref.tag_name == similar_pref.tag_name:
|
||||
consolidated[i] = updated_pref
|
||||
break
|
||||
else:
|
||||
# Add as new preference
|
||||
consolidated.append(new_pref)
|
||||
|
||||
return consolidated
|
||||
|
||||
def _find_similar_preference(
|
||||
self,
|
||||
target_preference: PreferenceTag,
|
||||
existing_preferences: List[PreferenceTag],
|
||||
threshold: float
|
||||
) -> Optional[PreferenceTag]:
|
||||
"""Find similar preference in existing list.
|
||||
|
||||
Args:
|
||||
target_preference: Preference to find similarity for
|
||||
existing_preferences: List of existing preferences
|
||||
threshold: Similarity threshold
|
||||
|
||||
Returns:
|
||||
Similar preference if found, None otherwise
|
||||
"""
|
||||
target_name = target_preference.tag_name.lower().strip()
|
||||
|
||||
for existing_pref in existing_preferences:
|
||||
existing_name = existing_pref.tag_name.lower().strip()
|
||||
|
||||
# Simple similarity check based on common words
|
||||
similarity = self._calculate_text_similarity(target_name, existing_name)
|
||||
|
||||
if similarity >= threshold:
|
||||
return existing_pref
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Calculate simple text similarity based on common words.
|
||||
|
||||
Args:
|
||||
text1: First text
|
||||
text2: Second text
|
||||
|
||||
Returns:
|
||||
Similarity score between 0.0 and 1.0
|
||||
"""
|
||||
if not text1 or not text2:
|
||||
return 0.0
|
||||
|
||||
# Simple word-based similarity
|
||||
words1 = set(text1.lower().split())
|
||||
words2 = set(text2.lower().split())
|
||||
|
||||
if not words1 or not words2:
|
||||
return 0.0
|
||||
|
||||
intersection = words1.intersection(words2)
|
||||
union = words1.union(words2)
|
||||
|
||||
return len(intersection) / len(union) if union else 0.0
|
||||
|
||||
def _merge_preferences(
|
||||
self,
|
||||
existing_pref: PreferenceTag,
|
||||
new_pref: PreferenceTag,
|
||||
current_time: datetime
|
||||
) -> PreferenceTag:
|
||||
"""Merge two similar preferences.
|
||||
|
||||
Args:
|
||||
existing_pref: Existing preference
|
||||
new_pref: New preference to merge
|
||||
current_time: Current timestamp
|
||||
|
||||
Returns:
|
||||
Merged preference tag
|
||||
"""
|
||||
# Combine supporting evidence
|
||||
combined_evidence = list(set(
|
||||
existing_pref.supporting_evidence + new_pref.supporting_evidence
|
||||
))
|
||||
|
||||
# Combine conversation references
|
||||
combined_refs = list(set(
|
||||
existing_pref.conversation_references + new_pref.conversation_references
|
||||
))
|
||||
|
||||
# Calculate new confidence score (weighted average)
|
||||
evidence_weight = len(new_pref.supporting_evidence)
|
||||
total_weight = len(existing_pref.supporting_evidence) + evidence_weight
|
||||
|
||||
if total_weight > 0:
|
||||
new_confidence = (
|
||||
(existing_pref.confidence_score * len(existing_pref.supporting_evidence) +
|
||||
new_pref.confidence_score * evidence_weight) / total_weight
|
||||
)
|
||||
else:
|
||||
new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)
|
||||
|
||||
# Ensure confidence doesn't exceed 1.0
|
||||
new_confidence = min(new_confidence, 1.0)
|
||||
|
||||
# Combine context details
|
||||
combined_context = existing_pref.context_details
|
||||
if new_pref.context_details and new_pref.context_details not in combined_context:
|
||||
combined_context += f"; {new_pref.context_details}"
|
||||
|
||||
return PreferenceTag(
|
||||
tag_name=existing_pref.tag_name, # Keep original name
|
||||
confidence_score=new_confidence,
|
||||
supporting_evidence=combined_evidence,
|
||||
context_details=combined_context,
|
||||
category=existing_pref.category or new_pref.category,
|
||||
conversation_references=combined_refs,
|
||||
created_at=existing_pref.created_at,
|
||||
updated_at=current_time
|
||||
)
|
||||
97
api/app/core/memory/analytics/implicit_memory/data_source.py
Normal file
97
api/app/core/memory/analytics/implicit_memory/data_source.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Memory Data Source
|
||||
|
||||
Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryDataSource:
|
||||
"""Retrieves processed memory data from Neo4j using direct Cypher queries."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
neo4j_connector: Optional[Neo4jConnector] = None
|
||||
):
|
||||
self.db = db
|
||||
self.neo4j_connector = neo4j_connector or Neo4jConnector()
|
||||
self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)
|
||||
|
||||
def _parse_timestamp(self, timestamp: Any) -> datetime:
|
||||
"""Parse timestamp from various formats."""
|
||||
if isinstance(timestamp, str):
|
||||
return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
|
||||
elif timestamp is None:
|
||||
return datetime.now()
|
||||
return timestamp
|
||||
|
||||
def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
|
||||
"""Convert a Neo4j dict directly to UserMemorySummary."""
|
||||
try:
|
||||
content = summary_dict.get("content", summary_dict.get("summary", ""))
|
||||
if not content or not content.strip():
|
||||
return None
|
||||
|
||||
return UserMemorySummary(
|
||||
summary_id=summary_dict.get("id", summary_dict.get("uuid", "")),
|
||||
user_id=user_id,
|
||||
user_content=content,
|
||||
timestamp=self._parse_timestamp(summary_dict.get("created_at")),
|
||||
confidence_score=1.0,
|
||||
summary_type="memory_summary"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
|
||||
return None
|
||||
|
||||
async def get_user_summaries(
|
||||
self,
|
||||
user_id: str,
|
||||
time_range: Optional[TimeRange] = None,
|
||||
limit: int = 1000
|
||||
) -> List[UserMemorySummary]:
|
||||
"""Retrieve user memory summaries from Neo4j.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
time_range: Optional time range filter
|
||||
limit: Maximum number of summaries
|
||||
|
||||
Returns:
|
||||
List of user memory summaries
|
||||
"""
|
||||
try:
|
||||
start_date = time_range.start_date if time_range else None
|
||||
end_date = time_range.end_date if time_range else None
|
||||
|
||||
summary_dicts = await self.memory_summary_repo.find_by_group_id(
|
||||
group_id=user_id,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
summaries = []
|
||||
for summary_dict in summary_dicts:
|
||||
summary = self._dict_to_user_summary(summary_dict, user_id)
|
||||
if summary:
|
||||
summaries.append(summary)
|
||||
|
||||
logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
|
||||
return summaries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
234
api/app/core/memory/analytics/implicit_memory/habit_detector.py
Normal file
234
api/app/core/memory/analytics/implicit_memory/habit_detector.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Habit Detector for Implicit Memory System
|
||||
|
||||
This module implements the HabitDetector class that specializes in identifying
|
||||
and ranking behavioral habits from user memory summaries. It provides advanced
|
||||
habit analysis with confidence scoring, recency weighting, and current vs past
|
||||
habit distinction.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
|
||||
HabitAnalyzer,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
ConfidenceLevel,
|
||||
FrequencyPattern,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HabitDetector:
|
||||
"""Detects and ranks behavioral habits from user memory summaries."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
llm_model_id: Optional[str] = None
|
||||
):
|
||||
"""Initialize the habit detector.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
llm_model_id: Optional LLM model ID to use for analysis
|
||||
"""
|
||||
self.db = db
|
||||
self.llm_model_id = llm_model_id
|
||||
self.habit_analyzer = HabitAnalyzer(db, llm_model_id)
|
||||
|
||||
async def detect_habits(
|
||||
self,
|
||||
user_id: str,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
existing_habits: Optional[List[BehaviorHabit]] = None
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Detect behavioral habits from user summaries.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_summaries: List of user-specific memory summaries
|
||||
existing_habits: Optional existing habits for consolidation
|
||||
|
||||
Returns:
|
||||
List of detected and ranked behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If habit analysis fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for user {user_id}")
|
||||
return existing_habits or []
|
||||
|
||||
logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")
|
||||
|
||||
try:
|
||||
# Use the habit analyzer to extract habits
|
||||
detected_habits = await self.habit_analyzer.analyze_habits(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries,
|
||||
existing_habits=existing_habits
|
||||
)
|
||||
|
||||
# Apply advanced ranking and filtering
|
||||
ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)
|
||||
|
||||
# Distinguish current vs past habits
|
||||
categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)
|
||||
|
||||
logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
|
||||
return categorized_habits
|
||||
|
||||
except LLMClientException:
|
||||
logger.error(f"Habit detection failed for user {user_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Habit detection failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit detection failed: {e}") from e
|
||||
|
||||
def rank_habits_by_confidence_and_recency(
|
||||
self,
|
||||
habits: List[BehaviorHabit],
|
||||
confidence_weight: float = 0.6,
|
||||
recency_weight: float = 0.4
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Rank habits by confidence level and recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to rank
|
||||
confidence_weight: Weight for confidence score (0.0-1.0)
|
||||
recency_weight: Weight for recency score (0.0-1.0)
|
||||
|
||||
Returns:
|
||||
List of habits ranked by combined score
|
||||
"""
|
||||
if not habits:
|
||||
return []
|
||||
|
||||
logger.info(f"Ranking {len(habits)} habits by confidence and recency")
|
||||
|
||||
def calculate_ranking_score(habit: BehaviorHabit) -> float:
|
||||
"""Calculate combined ranking score for a habit."""
|
||||
|
||||
# Confidence score (0.0-1.0)
|
||||
confidence_scores = {
|
||||
ConfidenceLevel.HIGH: 1.0,
|
||||
ConfidenceLevel.MEDIUM: 0.6,
|
||||
ConfidenceLevel.LOW: 0.3
|
||||
}
|
||||
confidence_score = confidence_scores.get(habit.confidence_level, 0.3)
|
||||
|
||||
# Recency score (0.0-1.0)
|
||||
current_time = datetime.now()
|
||||
days_since_last = (current_time - habit.last_observed).days
|
||||
|
||||
# Exponential decay for recency (habits lose relevance over time)
|
||||
if days_since_last <= 7:
|
||||
recency_score = 1.0 # Very recent
|
||||
elif days_since_last <= 30:
|
||||
recency_score = 0.8 # Recent
|
||||
elif days_since_last <= 90:
|
||||
recency_score = 0.5 # Somewhat recent
|
||||
elif days_since_last <= 180:
|
||||
recency_score = 0.3 # Old
|
||||
else:
|
||||
recency_score = 0.1 # Very old
|
||||
|
||||
# Frequency pattern bonus
|
||||
frequency_bonuses = {
|
||||
FrequencyPattern.DAILY: 0.2,
|
||||
FrequencyPattern.WEEKLY: 0.15,
|
||||
FrequencyPattern.MONTHLY: 0.1,
|
||||
FrequencyPattern.SEASONAL: 0.05,
|
||||
FrequencyPattern.OCCASIONAL: 0.0,
|
||||
FrequencyPattern.EVENT_TRIGGERED: 0.05
|
||||
}
|
||||
frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)
|
||||
|
||||
# Evidence quality bonus
|
||||
evidence_bonus = min(len(habit.supporting_summaries) / 10.0, 0.1) # Max 0.1 bonus
|
||||
|
||||
# Current habit bonus
|
||||
current_bonus = 0.1 if habit.is_current else 0.0
|
||||
|
||||
# Calculate final score
|
||||
base_score = (confidence_score * confidence_weight +
|
||||
recency_score * recency_weight)
|
||||
|
||||
final_score = base_score + frequency_bonus + evidence_bonus + current_bonus
|
||||
|
||||
return min(final_score, 1.0) # Cap at 1.0
|
||||
|
||||
# Sort habits by ranking score (descending)
|
||||
ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True)
|
||||
|
||||
logger.info(f"Ranked habits with scores: {[calculate_ranking_score(h) for h in ranked_habits[:5]]}")
|
||||
|
||||
return ranked_habits
|
||||
|
||||
def distinguish_current_vs_past_habits(
|
||||
self,
|
||||
habits: List[BehaviorHabit],
|
||||
current_threshold_days: int = 30
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Distinguish between current and past habits based on recency.
|
||||
|
||||
Args:
|
||||
habits: List of habits to categorize
|
||||
current_threshold_days: Days threshold for considering a habit current
|
||||
|
||||
Returns:
|
||||
List of habits with updated is_current status
|
||||
"""
|
||||
if not habits:
|
||||
return []
|
||||
|
||||
current_time = datetime.now()
|
||||
cutoff_date = current_time - timedelta(days=current_threshold_days)
|
||||
|
||||
current_habits = []
|
||||
past_habits = []
|
||||
|
||||
for habit in habits:
|
||||
# Update is_current status based on last observation
|
||||
if habit.last_observed >= cutoff_date:
|
||||
# Create updated habit with is_current = True
|
||||
updated_habit = BehaviorHabit(
|
||||
habit_description=habit.habit_description,
|
||||
frequency_pattern=habit.frequency_pattern,
|
||||
time_context=habit.time_context,
|
||||
confidence_level=habit.confidence_level,
|
||||
supporting_summaries=habit.supporting_summaries,
|
||||
specific_examples=habit.specific_examples,
|
||||
first_observed=habit.first_observed,
|
||||
last_observed=habit.last_observed,
|
||||
is_current=True
|
||||
)
|
||||
current_habits.append(updated_habit)
|
||||
else:
|
||||
# Create updated habit with is_current = False
|
||||
updated_habit = BehaviorHabit(
|
||||
habit_description=habit.habit_description,
|
||||
frequency_pattern=habit.frequency_pattern,
|
||||
time_context=habit.time_context,
|
||||
confidence_level=habit.confidence_level,
|
||||
supporting_summaries=habit.supporting_summaries,
|
||||
specific_examples=habit.specific_examples,
|
||||
first_observed=habit.first_observed,
|
||||
last_observed=habit.last_observed,
|
||||
is_current=False
|
||||
)
|
||||
past_habits.append(updated_habit)
|
||||
|
||||
# Return current habits first, then past habits
|
||||
categorized_habits = current_habits + past_habits
|
||||
|
||||
logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")
|
||||
|
||||
return categorized_habits
|
||||
321
api/app/core/memory/analytics/implicit_memory/llm_client.py
Normal file
321
api/app/core/memory/analytics/implicit_memory/llm_client.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""LLM Client Wrapper for Implicit Memory Analysis
|
||||
|
||||
This module provides a specialized LLM client wrapper that integrates with the
|
||||
MemoryClientFactory to perform implicit memory analysis tasks including preference
|
||||
extraction, personality dimension analysis, interest categorization, and habit detection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.prompts import (
|
||||
get_dimension_analysis_prompt,
|
||||
get_habit_analysis_prompt,
|
||||
get_interest_analysis_prompt,
|
||||
get_preference_analysis_prompt,
|
||||
)
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.schemas.implicit_memory_schema import UserMemorySummary
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Response Models for LLM Analysis
|
||||
|
||||
class PreferenceAnalysisResponse(BaseModel):
|
||||
"""Response model for preference analysis."""
|
||||
preferences: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DimensionAnalysisResponse(BaseModel):
|
||||
"""Response model for dimension analysis."""
|
||||
dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class InterestAnalysisResponse(BaseModel):
|
||||
"""Response model for interest analysis."""
|
||||
interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class HabitAnalysisResponse(BaseModel):
|
||||
"""Response model for habit analysis."""
|
||||
habits: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ImplicitMemoryLLMClient:
|
||||
"""LLM client wrapper for implicit memory analysis.
|
||||
|
||||
This class provides a high-level interface for performing LLM-based analysis
|
||||
of user memory summaries to extract preferences, personality dimensions,
|
||||
interests, and behavioral habits.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, default_model_id: Optional[str] = None):
|
||||
"""Initialize the LLM client wrapper.
|
||||
|
||||
Args:
|
||||
db: Database session for accessing model configurations
|
||||
default_model_id: Default LLM model ID to use if none specified
|
||||
"""
|
||||
self.db = db
|
||||
self.default_model_id = default_model_id
|
||||
self._client_factory = MemoryClientFactory(db)
|
||||
|
||||
logger.info("ImplicitMemoryLLMClient initialized")
|
||||
|
||||
def _get_llm_client(self, model_id: Optional[str] = None):
|
||||
"""Get LLM client instance.
|
||||
|
||||
Args:
|
||||
model_id: LLM model ID to use, defaults to default_model_id
|
||||
|
||||
Returns:
|
||||
LLM client instance
|
||||
|
||||
Raises:
|
||||
ValueError: If no model ID is provided and no default is set
|
||||
LLMClientException: If client creation fails
|
||||
"""
|
||||
effective_model_id = model_id or self.default_model_id
|
||||
if not effective_model_id:
|
||||
raise ValueError("No LLM model ID provided and no default model ID set")
|
||||
|
||||
try:
|
||||
client = self._client_factory.get_llm_client(effective_model_id)
|
||||
logger.debug(f"Created LLM client for model: {effective_model_id}")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
|
||||
raise LLMClientException(f"Failed to create LLM client: {e}") from e
|
||||
|
||||
def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
|
||||
"""Prepare user memory summaries for LLM analysis.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries
|
||||
|
||||
Returns:
|
||||
List of formatted summary dictionaries
|
||||
"""
|
||||
formatted_summaries = []
|
||||
for summary in user_summaries:
|
||||
formatted_summary = {
|
||||
'summary_id': summary.summary_id,
|
||||
'user_content': summary.user_content,
|
||||
'timestamp': summary.timestamp.isoformat(),
|
||||
'summary_type': summary.summary_type,
|
||||
'confidence_score': summary.confidence_score
|
||||
}
|
||||
formatted_summaries.append(formatted_summary)
|
||||
|
||||
logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
|
||||
return formatted_summaries
|
||||
|
||||
async def analyze_preferences(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user preferences from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing extracted preferences
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for preference analysis of user {user_id}")
|
||||
return {"preferences": []}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for preference analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_preference_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=PreferenceAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Preference analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Preference analysis failed: {e}") from e
|
||||
|
||||
async def analyze_dimensions(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user personality dimensions from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing dimension scores and analysis
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
|
||||
return {"dimensions": {}}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for dimension analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_dimension_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=DimensionAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
dimensions = result.get('dimensions', {})
|
||||
logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Dimension analysis failed: {e}") from e
|
||||
|
||||
async def analyze_interests(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user interest distribution from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing interest area distribution
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for interest analysis of user {user_id}")
|
||||
return {"interest_distribution": {}}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for interest analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_interest_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
interest_dist = result.get('interest_distribution', {})
|
||||
logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Interest analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Interest analysis failed: {e}") from e
|
||||
|
||||
async def analyze_habits(
|
||||
self,
|
||||
user_summaries: List[UserMemorySummary],
|
||||
user_id: str,
|
||||
model_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Analyze user behavioral habits from memory summaries.
|
||||
|
||||
Args:
|
||||
user_summaries: List of user memory summaries to analyze
|
||||
user_id: Target user ID for analysis
|
||||
model_id: Optional LLM model ID to use
|
||||
|
||||
Returns:
|
||||
Dictionary containing identified behavioral habits
|
||||
|
||||
Raises:
|
||||
LLMClientException: If LLM analysis fails
|
||||
ValueError: If input validation fails
|
||||
"""
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries provided for habit analysis of user {user_id}")
|
||||
return {"habits": []}
|
||||
|
||||
if not user_id:
|
||||
raise ValueError("User ID is required for habit analysis")
|
||||
|
||||
try:
|
||||
# Prepare summaries and get prompt
|
||||
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
|
||||
prompt = get_habit_analysis_prompt(formatted_summaries, user_id)
|
||||
|
||||
# Get LLM client and perform analysis
|
||||
llm_client = self._get_llm_client(model_id)
|
||||
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
# Use structured output for reliable parsing
|
||||
response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=HabitAnalysisResponse
|
||||
)
|
||||
|
||||
result = response.model_dump()
|
||||
logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Habit analysis failed for user {user_id}: {e}")
|
||||
raise LLMClientException(f"Habit analysis failed: {e}") from e
|
||||
69
api/app/core/memory/analytics/implicit_memory/prompts.py
Normal file
69
api/app/core/memory/analytics/implicit_memory/prompts.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""LLM Prompt Templates for Implicit Memory Analysis
|
||||
|
||||
This module contains prompt rendering functions for analyzing user memory summaries
|
||||
to extract preferences, personality dimensions, interests, and behavioral habits.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
|
||||
# Setup Jinja2 environment
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompt_dir = os.path.join(current_dir, "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
def _render_template(template_name: str, **kwargs) -> str:
|
||||
"""Helper function to render Jinja2 templates."""
|
||||
template = prompt_env.get_template(template_name)
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
def get_preference_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted preference analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"preference_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_dimension_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted dimension analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"dimension_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_interest_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted interest analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"interest_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
def get_habit_analysis_prompt(
|
||||
memory_summaries: List[Dict[str, Any]],
|
||||
user_id: str
|
||||
) -> str:
|
||||
"""Get formatted habit analysis prompt using Jinja2 template."""
|
||||
return _render_template(
|
||||
"habit_analysis.jinja2",
|
||||
memory_summaries=memory_summaries,
|
||||
user_id=user_id
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Dimensions to Analyze
|
||||
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
|
||||
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
|
||||
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
|
||||
4. **Literature** (0-100%): Reading habits, writing style, literary references
|
||||
|
||||
## Instructions
|
||||
1. Analyze the user's content for each dimension
|
||||
2. Calculate percentage scores (0-100%)
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 0-100},
|
||||
"aesthetic": {"percentage": 0-100},
|
||||
"technology": {"percentage": 0-100},
|
||||
"literature": {"percentage": 0-100}
|
||||
}
|
||||
}
|
||||
|
||||
## Example
|
||||
{
|
||||
"dimensions": {
|
||||
"creativity": {"percentage": 75},
|
||||
"aesthetic": {"percentage": 45},
|
||||
"technology": {"percentage": 60},
|
||||
"literature": {"percentage": 30}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
You are an expert at identifying behavioral patterns and habits from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
|
||||
2. Focus on specific, concrete habits with temporal patterns
|
||||
3. For each habit, provide:
|
||||
- habit_description: Clear, specific description
|
||||
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
|
||||
- time_context: When it typically happens
|
||||
- confidence_level: "high", "medium", "low"
|
||||
- supporting_summaries: References to evidence
|
||||
- specific_examples: Concrete examples from summaries
|
||||
- is_current: true if current habit, false if past habit
|
||||
4. Only include habits with medium or high confidence
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "string",
|
||||
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
|
||||
"time_context": "string",
|
||||
"confidence_level": "high|medium|low",
|
||||
"supporting_summaries": ["id1", "id2"],
|
||||
"specific_examples": ["example1", "example2"],
|
||||
"is_current": true|false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "drinks coffee every morning",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "morning routine",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["needs coffee to start the day"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"habits": [
|
||||
{
|
||||
"habit_description": "每天早上喝咖啡",
|
||||
"frequency_pattern": "daily",
|
||||
"time_context": "早晨日常",
|
||||
"confidence_level": "high",
|
||||
"supporting_summaries": ["s1", "s2"],
|
||||
"specific_examples": ["需要咖啡来开始一天"],
|
||||
"is_current": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
You are an expert at analyzing user interests from memory summaries.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Interest Categories
|
||||
1. **Tech**: Programming, technology, software tools, hardware
|
||||
2. **Lifestyle**: Daily routines, health, hobbies, social activities
|
||||
3. **Music**: Music preferences, instruments, concerts
|
||||
4. **Art**: Visual arts, creative projects, design, aesthetics
|
||||
|
||||
## Instructions
|
||||
1. Categorize the user's interests into the four areas
|
||||
2. Calculate percentage distribution (must total 100%)
|
||||
3. Provide specific evidence for each interest area
|
||||
4. Use "increasing", "decreasing", or "stable" for trending direction
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
|
||||
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"interest_distribution": {
|
||||
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
|
||||
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
|
||||
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
|
||||
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
You are an expert at analyzing user memory summaries to identify implicit preferences.
|
||||
|
||||
## Memory Summaries
|
||||
{% for summary in memory_summaries %}
|
||||
Summary {{ loop.index }}:
|
||||
{{ summary.content or summary.user_content or '' }}
|
||||
---
|
||||
{% endfor %}
|
||||
|
||||
## Target User ID
|
||||
{{ user_id }}
|
||||
|
||||
## Instructions
|
||||
1. Focus ONLY on the specified user's preferences
|
||||
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
|
||||
3. DO NOT use long phrases - use short nouns or noun phrases
|
||||
4. Only include preferences with confidence_score >= 0.3
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
"preferences": [
|
||||
{
|
||||
"tag_name": "short tag",
|
||||
"confidence_score": 0.0-1.0,
|
||||
"supporting_evidence": ["evidence1", "evidence2"],
|
||||
"context_details": "brief context",
|
||||
"category": "category or null"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (Chinese input → Chinese output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
|
||||
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
|
||||
]
|
||||
}
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"preferences": [
|
||||
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
|
||||
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
|
||||
]
|
||||
}
|
||||
@@ -76,7 +76,7 @@ class HttpContentTypeConfig(BaseModel):
|
||||
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is JSON, data must be of type str")
|
||||
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
|
||||
raise ValueError("When content_type is x-www-form-urlencoded, data must be a object")
|
||||
raise ValueError("When content_type is x-www-form-urlencoded, data must be an object(dict)")
|
||||
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
|
||||
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
|
||||
return v
|
||||
|
||||
@@ -26,6 +26,7 @@ from .tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
|
||||
)
|
||||
from .memory_perceptual_model import MemoryPerceptualModel
|
||||
|
||||
__all__ = [
|
||||
"Tenants",
|
||||
@@ -74,5 +75,6 @@ __all__ = [
|
||||
"ToolType",
|
||||
"ToolStatus",
|
||||
"AuthType",
|
||||
"ExecutionStatus"
|
||||
"ExecutionStatus",
|
||||
"MemoryPerceptualModel"
|
||||
]
|
||||
|
||||
273
api/app/repositories/neo4j/memory_summary_repository.py
Normal file
273
api/app/repositories/neo4j/memory_summary_repository.py
Normal file
@@ -0,0 +1,273 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory Summary Repository Module
|
||||
|
||||
This module provides data access functionality for MemorySummary nodes.
|
||||
|
||||
Classes:
|
||||
MemorySummaryRepository: Repository for managing MemorySummary CRUD operations
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""Memory Summary Repository
|
||||
|
||||
Manages CRUD operations for MemorySummary nodes.
|
||||
Provides methods to query summaries by group_id, user_id, and time ranges.
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j connector instance
|
||||
node_label: Node label, fixed as "MemorySummary"
|
||||
"""
|
||||
|
||||
def __init__(self, connector: Neo4jConnector):
|
||||
"""Initialize memory summary repository
|
||||
|
||||
Args:
|
||||
connector: Neo4j connector instance
|
||||
"""
|
||||
super().__init__(connector, "MemorySummary")
|
||||
|
||||
def _map_to_dict(self, node_data: Dict) -> Dict[str, Any]:
|
||||
"""Map node data to dictionary format
|
||||
|
||||
Args:
|
||||
node_data: Node data returned from Neo4j query
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Memory summary data dictionary
|
||||
"""
|
||||
# Extract node data from query result
|
||||
n = node_data.get('n', node_data)
|
||||
|
||||
# Handle datetime fields
|
||||
if isinstance(n.get('created_at'), str):
|
||||
n['created_at'] = datetime.fromisoformat(n['created_at'])
|
||||
|
||||
return dict(n)
|
||||
|
||||
async def find_by_group_id(
|
||||
self,
|
||||
group_id: str,
|
||||
limit: int = 1000,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by group_id
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of memory summary dictionaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
"""
|
||||
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
|
||||
# Add date range filters if provided
|
||||
if start_date:
|
||||
query += " AND n.created_at >= $start_date"
|
||||
params["start_date"] = start_date
|
||||
|
||||
if end_date:
|
||||
query += " AND n.created_at <= $end_date"
|
||||
params["end_date"] = end_date
|
||||
|
||||
query += """
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def find_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
limit: int = 1000,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by user_id
|
||||
|
||||
Args:
|
||||
user_id: User ID to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of memory summary dictionaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.user_id = $user_id
|
||||
"""
|
||||
|
||||
params = {"user_id": user_id, "limit": limit}
|
||||
|
||||
# Add date range filters if provided
|
||||
if start_date:
|
||||
query += " AND n.created_at >= $start_date"
|
||||
params["start_date"] = start_date
|
||||
|
||||
if end_date:
|
||||
query += " AND n.created_at <= $end_date"
|
||||
params["end_date"] = end_date
|
||||
|
||||
query += """
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def find_by_group_and_user(
|
||||
self,
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
limit: int = 1000,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by both group_id and user_id
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
user_id: User ID to filter by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional start date filter
|
||||
end_date: Optional end date filter
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of memory summary dictionaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id AND n.user_id = $user_id
|
||||
"""
|
||||
|
||||
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
|
||||
|
||||
# Add date range filters if provided
|
||||
if start_date:
|
||||
query += " AND n.created_at >= $start_date"
|
||||
params["start_date"] = start_date
|
||||
|
||||
if end_date:
|
||||
query += " AND n.created_at <= $end_date"
|
||||
params["end_date"] = end_date
|
||||
|
||||
query += """
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def find_recent_summaries(
|
||||
self,
|
||||
group_id: str,
|
||||
days: int = 7,
|
||||
limit: int = 1000
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query recent memory summaries
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
days: Number of recent days to query
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of memory summary dictionaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
AND n.created_at >= datetime() - duration({{days: $days}})
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(
|
||||
query,
|
||||
group_id=group_id,
|
||||
days=days,
|
||||
limit=limit
|
||||
)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def find_by_content_keywords(
|
||||
self,
|
||||
group_id: str,
|
||||
keywords: List[str],
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by content keywords
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
keywords: List of keywords to search for in content
|
||||
limit: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: List of memory summary dictionaries
|
||||
"""
|
||||
# Build keyword search conditions
|
||||
keyword_conditions = []
|
||||
params = {"group_id": group_id, "limit": limit}
|
||||
|
||||
for i, keyword in enumerate(keywords):
|
||||
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
|
||||
params[f"keyword_{i}"] = keyword
|
||||
|
||||
keyword_filter = " OR ".join(keyword_conditions)
|
||||
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
AND ({keyword_filter})
|
||||
RETURN n
|
||||
ORDER BY n.created_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
return [self._map_to_dict(r) for r in results]
|
||||
|
||||
async def get_summary_count_by_group(self, group_id: str) -> int:
|
||||
"""Get count of memory summaries for a group
|
||||
|
||||
Args:
|
||||
group_id: Group ID to count summaries for
|
||||
|
||||
Returns:
|
||||
int: Number of memory summaries
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{self.node_label})
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
|
||||
results = await self.connector.execute_query(query, group_id=group_id)
|
||||
return results[0]['count'] if results else 0
|
||||
|
||||
212
api/app/schemas/implicit_memory_schema.py
Normal file
212
api/app/schemas/implicit_memory_schema.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Implicit Memory Schemas
|
||||
|
||||
This module defines the Pydantic schemas for the implicit memory system API.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class ConfidenceLevel(str, Enum):
|
||||
"""Confidence levels for analysis results."""
|
||||
HIGH = "high"
|
||||
MEDIUM = "medium"
|
||||
LOW = "low"
|
||||
|
||||
|
||||
class FrequencyPattern(str, Enum):
|
||||
"""Frequency patterns for habits."""
|
||||
DAILY = "daily"
|
||||
WEEKLY = "weekly"
|
||||
MONTHLY = "monthly"
|
||||
SEASONAL = "seasonal"
|
||||
OCCASIONAL = "occasional"
|
||||
EVENT_TRIGGERED = "event_triggered"
|
||||
|
||||
|
||||
# Request Schemas
|
||||
|
||||
class TimeRange(BaseModel):
|
||||
"""Time range for analysis."""
|
||||
start_date: datetime.datetime
|
||||
end_date: datetime.datetime
|
||||
|
||||
@field_validator('end_date')
|
||||
@classmethod
|
||||
def end_date_must_be_after_start_date(cls, v, info):
|
||||
if 'start_date' in info.data and v <= info.data['start_date']:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return v
|
||||
|
||||
|
||||
class DateRange(BaseModel):
|
||||
"""Date range for filtering."""
|
||||
start_date: Optional[datetime.datetime] = None
|
||||
end_date: Optional[datetime.datetime] = None
|
||||
|
||||
@field_validator('end_date')
|
||||
@classmethod
|
||||
def end_date_must_be_after_start_date(cls, v, info):
|
||||
if v and 'start_date' in info.data and info.data['start_date'] and v <= info.data['start_date']:
|
||||
raise ValueError('end_date must be after start_date')
|
||||
return v
|
||||
|
||||
|
||||
class AnalysisConfig(BaseModel):
|
||||
"""Configuration for analysis operations."""
|
||||
llm_model_id: Optional[str] = None
|
||||
batch_size: int = 100
|
||||
confidence_threshold: float = 0.5
|
||||
include_historical_trends: bool = False
|
||||
max_conversations: Optional[int] = None
|
||||
|
||||
|
||||
# Response Schemas
|
||||
|
||||
class PreferenceTagResponse(BaseModel):
|
||||
"""A user preference tag with detailed context."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
tag_name: str
|
||||
confidence_score: float = Field(ge=0.0, le=1.0)
|
||||
supporting_evidence: List[str]
|
||||
context_details: str
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
conversation_references: List[str]
|
||||
category: Optional[str] = None
|
||||
|
||||
|
||||
class DimensionScoreResponse(BaseModel):
|
||||
"""Score for a personality dimension."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
dimension_name: str
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str]
|
||||
reasoning: str
|
||||
confidence_level: ConfidenceLevel
|
||||
|
||||
|
||||
class DimensionPortraitResponse(BaseModel):
|
||||
"""Four-dimension personality portrait."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
creativity: DimensionScoreResponse
|
||||
aesthetic: DimensionScoreResponse
|
||||
technology: DimensionScoreResponse
|
||||
literature: DimensionScoreResponse
|
||||
analysis_timestamp: datetime.datetime
|
||||
total_summaries_analyzed: int
|
||||
historical_trends: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
|
||||
class InterestCategoryResponse(BaseModel):
|
||||
"""Interest category with percentage and evidence."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
category_name: str
|
||||
percentage: float = Field(ge=0.0, le=100.0)
|
||||
evidence: List[str]
|
||||
trending_direction: Optional[str] = None
|
||||
|
||||
|
||||
class InterestAreaDistributionResponse(BaseModel):
|
||||
"""Distribution of user interests across four areas."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
tech: InterestCategoryResponse
|
||||
lifestyle: InterestCategoryResponse
|
||||
music: InterestCategoryResponse
|
||||
art: InterestCategoryResponse
|
||||
analysis_timestamp: datetime.datetime
|
||||
total_summaries_analyzed: int
|
||||
|
||||
@property
|
||||
def total_percentage(self) -> float:
|
||||
"""Calculate total percentage across all interest areas."""
|
||||
return self.tech.percentage + self.lifestyle.percentage + self.music.percentage + self.art.percentage
|
||||
|
||||
|
||||
class BehaviorHabitResponse(BaseModel):
|
||||
"""A behavioral habit identified from conversations."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
habit_description: str
|
||||
frequency_pattern: FrequencyPattern
|
||||
time_context: str
|
||||
confidence_level: ConfidenceLevel
|
||||
supporting_summaries: List[str]
|
||||
first_observed: datetime.datetime
|
||||
last_observed: datetime.datetime
|
||||
is_current: bool = True
|
||||
specific_examples: List[str]
|
||||
|
||||
|
||||
class UserProfileResponse(BaseModel):
|
||||
"""Comprehensive user profile."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
preference_tags: List[PreferenceTagResponse]
|
||||
dimension_portrait: DimensionPortraitResponse
|
||||
interest_area_distribution: InterestAreaDistributionResponse
|
||||
behavior_habits: List[BehaviorHabitResponse]
|
||||
profile_version: int
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
total_summaries_analyzed: int
|
||||
analysis_completeness_score: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
|
||||
# Internal/Business Logic Schemas
|
||||
|
||||
class MemorySummary(BaseModel):
|
||||
"""Memory summary from existing storage system."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
summary_id: str
|
||||
content: str
|
||||
timestamp: datetime.datetime
|
||||
participants: List[str]
|
||||
summary_type: str
|
||||
|
||||
|
||||
class UserMemorySummary(BaseModel):
|
||||
"""Memory summary filtered for specific user content."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
summary_id: str
|
||||
user_id: str
|
||||
user_content: str
|
||||
timestamp: datetime.datetime
|
||||
confidence_score: float = Field(ge=0.0, le=1.0)
|
||||
summary_type: str
|
||||
|
||||
|
||||
class SummaryAnalysisResult(BaseModel):
|
||||
"""Result of analyzing memory summaries."""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
user_id: str
|
||||
preferences: List[PreferenceTagResponse]
|
||||
dimension_evidence: Dict[str, List[str]]
|
||||
interest_evidence: Dict[str, List[str]]
|
||||
habit_evidence: List[Dict[str, Any]]
|
||||
analysis_timestamp: datetime.datetime
|
||||
summaries_analyzed: List[str]
|
||||
|
||||
|
||||
# Aliases for backward compatibility with existing code
|
||||
PreferenceTag = PreferenceTagResponse
|
||||
DimensionScore = DimensionScoreResponse
|
||||
DimensionPortrait = DimensionPortraitResponse
|
||||
InterestCategory = InterestCategoryResponse
|
||||
InterestAreaDistribution = InterestAreaDistributionResponse
|
||||
BehaviorHabit = BehaviorHabitResponse
|
||||
UserProfile = UserProfileResponse
|
||||
385
api/app/services/implicit_memory_service.py
Normal file
385
api/app/services/implicit_memory_service.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Implicit Memory Service
|
||||
|
||||
Main service orchestrating all implicit memory operations. This service coordinates
|
||||
profile building, data extraction, and provides high-level methods for analyzing
|
||||
user profiles from memory summaries.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.dimension_analyzer import (
|
||||
DimensionAnalyzer,
|
||||
)
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.interest_analyzer import (
|
||||
InterestAnalyzer,
|
||||
)
|
||||
from app.core.memory.analytics.implicit_memory.analyzers.preference_analyzer import (
|
||||
PreferenceAnalyzer,
|
||||
)
|
||||
from app.core.memory.analytics.implicit_memory.data_source import MemoryDataSource
|
||||
from app.core.memory.analytics.implicit_memory.habit_detector import HabitDetector
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.implicit_memory_schema import (
|
||||
BehaviorHabit,
|
||||
ConfidenceLevel,
|
||||
DateRange,
|
||||
DimensionPortrait,
|
||||
FrequencyPattern,
|
||||
InterestAreaDistribution,
|
||||
PreferenceTag,
|
||||
TimeRange,
|
||||
UserMemorySummary,
|
||||
)
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryService:
|
||||
"""Main service for implicit memory operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
):
|
||||
"""Initialize the implicit memory service.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
end_user_id: End user ID to get connected memory configuration
|
||||
"""
|
||||
self.db = db
|
||||
self.end_user_id = end_user_id
|
||||
|
||||
# Get connected memory configuration for the user
|
||||
self.memory_config = self._get_user_memory_config()
|
||||
|
||||
# Extract LLM model ID from memory config
|
||||
llm_model_id = str(self.memory_config.llm_model_id) if self.memory_config.llm_model_id else None
|
||||
|
||||
# Initialize Neo4j connector
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
# Initialize core components with LLM model ID
|
||||
self.data_source = MemoryDataSource(db, self.neo4j_connector)
|
||||
self.preference_analyzer = PreferenceAnalyzer(db, llm_model_id)
|
||||
self.dimension_analyzer = DimensionAnalyzer(db, llm_model_id)
|
||||
self.interest_analyzer = InterestAnalyzer(db, llm_model_id)
|
||||
self.habit_detector = HabitDetector(db, llm_model_id)
|
||||
|
||||
logger.info(f"ImplicitMemoryService initialized for end_user: {end_user_id}")
|
||||
|
||||
def _get_user_memory_config(self) -> MemoryConfig:
|
||||
"""Get memory configuration for the connected end user.
|
||||
|
||||
Returns:
|
||||
MemoryConfig: User's connected memory configuration
|
||||
|
||||
Raises:
|
||||
ValueError: If no memory configuration found for user
|
||||
"""
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# Get user's connected config
|
||||
connected_config = get_end_user_connected_config(self.end_user_id, self.db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user: {self.end_user_id}")
|
||||
|
||||
# Load the memory configuration
|
||||
config_service = MemoryConfigService(self.db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
|
||||
logger.info(f"Loaded memory config {config_id} for end_user: {self.end_user_id}")
|
||||
return memory_config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get memory config for end_user {self.end_user_id}: {e}")
|
||||
raise ValueError(f"Unable to get memory configuration for end_user {self.end_user_id}: {e}")
|
||||
|
||||
async def extract_user_summaries(
|
||||
self,
|
||||
user_id: str,
|
||||
time_range: Optional[TimeRange] = None,
|
||||
limit: Optional[int] = None
|
||||
) -> List[UserMemorySummary]:
|
||||
"""Extract user-specific memory summaries.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
time_range: Optional time range to filter summaries
|
||||
limit: Optional limit on number of summaries
|
||||
|
||||
Returns:
|
||||
List of user-specific memory summaries
|
||||
"""
|
||||
logger.info(f"Extracting user summaries for user {user_id}")
|
||||
|
||||
try:
|
||||
summaries = await self.data_source.get_user_summaries(
|
||||
user_id=user_id,
|
||||
time_range=time_range,
|
||||
limit=limit or 1000
|
||||
)
|
||||
|
||||
logger.info(f"Extracted {len(summaries)} summaries for user {user_id}")
|
||||
return summaries
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract user summaries for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_preference_tags(
|
||||
self,
|
||||
user_id: str,
|
||||
confidence_threshold: float = 0.5,
|
||||
tag_category: Optional[str] = None,
|
||||
date_range: Optional[DateRange] = None
|
||||
) -> List[PreferenceTag]:
|
||||
"""Retrieve user preference tags with filtering.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_threshold: Minimum confidence score for tags
|
||||
tag_category: Optional category filter
|
||||
date_range: Optional date range filter
|
||||
|
||||
Returns:
|
||||
List of filtered preference tags
|
||||
"""
|
||||
logger.info(f"Getting preference tags for user {user_id}")
|
||||
|
||||
try:
|
||||
# Get user summaries for analysis
|
||||
time_range = None
|
||||
if date_range:
|
||||
time_range = TimeRange(
|
||||
start_date=date_range.start_date or datetime.min,
|
||||
end_date=date_range.end_date or datetime.now()
|
||||
)
|
||||
|
||||
user_summaries = await self.extract_user_summaries(
|
||||
user_id=user_id,
|
||||
time_range=time_range
|
||||
)
|
||||
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries found for user {user_id}")
|
||||
return []
|
||||
|
||||
# Analyze preferences
|
||||
preference_tags = await self.preference_analyzer.analyze_preferences(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
filtered_tags = []
|
||||
for tag in preference_tags:
|
||||
# Filter by confidence threshold
|
||||
if tag.confidence_score < confidence_threshold:
|
||||
continue
|
||||
|
||||
# Filter by category if specified
|
||||
if tag_category and tag.category != tag_category:
|
||||
continue
|
||||
|
||||
# Filter by date range if specified
|
||||
if date_range:
|
||||
if date_range.start_date and tag.created_at < date_range.start_date:
|
||||
continue
|
||||
if date_range.end_date and tag.created_at > date_range.end_date:
|
||||
continue
|
||||
|
||||
filtered_tags.append(tag)
|
||||
|
||||
# Sort by confidence score and recency
|
||||
filtered_tags.sort(
|
||||
key=lambda x: (x.confidence_score, x.updated_at),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved {len(filtered_tags)} preference tags for user {user_id}")
|
||||
return filtered_tags
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get preference tags for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_dimension_portrait(
|
||||
self,
|
||||
user_id: str,
|
||||
include_history: bool = False
|
||||
) -> DimensionPortrait:
|
||||
"""Get user's four-dimension personality portrait.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_history: Whether to include historical trends
|
||||
|
||||
Returns:
|
||||
User's dimension portrait
|
||||
"""
|
||||
logger.info(f"Getting dimension portrait for user {user_id}")
|
||||
|
||||
try:
|
||||
# Get user summaries
|
||||
user_summaries = await self.extract_user_summaries(user_id=user_id)
|
||||
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries found for user {user_id}")
|
||||
return self.dimension_analyzer._create_empty_portrait(user_id)
|
||||
|
||||
# Analyze dimensions
|
||||
dimension_portrait = await self.dimension_analyzer.analyze_dimensions(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries
|
||||
)
|
||||
|
||||
# Include historical trends if requested
|
||||
if include_history:
|
||||
# In a full implementation, this would retrieve historical data
|
||||
# For now, we'll leave historical_trends as None
|
||||
pass
|
||||
|
||||
logger.info(f"Retrieved dimension portrait for user {user_id}")
|
||||
return dimension_portrait
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get dimension portrait for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_interest_area_distribution(
|
||||
self,
|
||||
user_id: str,
|
||||
include_trends: bool = False
|
||||
) -> InterestAreaDistribution:
|
||||
"""Get user's interest area distribution across four areas.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
include_trends: Whether to include trending information
|
||||
|
||||
Returns:
|
||||
User's interest area distribution
|
||||
"""
|
||||
logger.info(f"Getting interest area distribution for user {user_id}")
|
||||
|
||||
try:
|
||||
# Get user summaries
|
||||
user_summaries = await self.extract_user_summaries(user_id=user_id)
|
||||
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries found for user {user_id}")
|
||||
return self.interest_analyzer._create_empty_distribution(user_id)
|
||||
|
||||
# Analyze interests
|
||||
interest_distribution = await self.interest_analyzer.analyze_interests(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries
|
||||
)
|
||||
|
||||
# Include trends if requested
|
||||
if include_trends:
|
||||
# In a full implementation, this would calculate trending directions
|
||||
# For now, we'll leave trending_direction as None for each category
|
||||
pass
|
||||
|
||||
logger.info(f"Retrieved interest area distribution for user {user_id}")
|
||||
return interest_distribution
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get interest area distribution for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
async def get_behavior_habits(
|
||||
self,
|
||||
user_id: str,
|
||||
confidence_level: Optional[str] = None,
|
||||
frequency_pattern: Optional[str] = None,
|
||||
time_period: Optional[str] = None
|
||||
) -> List[BehaviorHabit]:
|
||||
"""Get user's behavioral habits with filtering.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
confidence_level: Optional confidence level filter ("high", "medium", "low")
|
||||
frequency_pattern: Optional frequency pattern filter
|
||||
time_period: Optional time period filter ("current", "past")
|
||||
|
||||
Returns:
|
||||
List of filtered behavioral habits
|
||||
"""
|
||||
logger.info(f"Getting behavior habits for user {user_id}")
|
||||
|
||||
try:
|
||||
# Get user summaries
|
||||
user_summaries = await self.extract_user_summaries(user_id=user_id)
|
||||
|
||||
if not user_summaries:
|
||||
logger.warning(f"No summaries found for user {user_id}")
|
||||
return []
|
||||
|
||||
# Detect habits
|
||||
behavior_habits = await self.habit_detector.detect_habits(
|
||||
user_id=user_id,
|
||||
user_summaries=user_summaries
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
filtered_habits = []
|
||||
for habit in behavior_habits:
|
||||
# Filter by confidence level
|
||||
if confidence_level:
|
||||
try:
|
||||
target_confidence = ConfidenceLevel(confidence_level.lower())
|
||||
if habit.confidence_level != target_confidence:
|
||||
continue
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid confidence level: {confidence_level}")
|
||||
continue
|
||||
|
||||
# Filter by frequency pattern
|
||||
if frequency_pattern:
|
||||
try:
|
||||
target_frequency = FrequencyPattern(frequency_pattern.lower())
|
||||
if habit.frequency_pattern != target_frequency:
|
||||
continue
|
||||
except ValueError:
|
||||
logger.warning(f"Invalid frequency pattern: {frequency_pattern}")
|
||||
continue
|
||||
|
||||
# Filter by time period
|
||||
if time_period:
|
||||
if time_period.lower() == "current" and not habit.is_current:
|
||||
continue
|
||||
elif time_period.lower() == "past" and habit.is_current:
|
||||
continue
|
||||
|
||||
filtered_habits.append(habit)
|
||||
|
||||
# Sort by confidence level and recency
|
||||
confidence_order = {"high": 3, "medium": 2, "low": 1}
|
||||
filtered_habits.sort(
|
||||
key=lambda x: (
|
||||
confidence_order.get(x.confidence_level.value, 0),
|
||||
x.last_observed
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user {user_id}")
|
||||
return filtered_habits
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get behavior habits for user {user_id}: {e}")
|
||||
raise
|
||||
|
||||
@@ -68,12 +68,7 @@ edges:
|
||||
label: 完成
|
||||
|
||||
# 变量定义
|
||||
variables:
|
||||
- name: user_question
|
||||
type: string
|
||||
required: true
|
||||
description: 用户的问题
|
||||
default: ""
|
||||
variables: []
|
||||
|
||||
# 执行配置
|
||||
execution_config:
|
||||
|
||||
88
api/migrations/versions/c6d4afa27bf0_202601071800.py
Normal file
88
api/migrations/versions/c6d4afa27bf0_202601071800.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""202601071800
|
||||
|
||||
Revision ID: c6d4afa27bf0
|
||||
Revises: 8372101eda28
|
||||
Create Date: 2026-01-07 17:59:23.032323
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c6d4afa27bf0'
|
||||
down_revision: Union[str, None] = '8372101eda28'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('memory_long_term',
|
||||
sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'),
|
||||
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
|
||||
sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_memory_long_term_created_at'), 'memory_long_term', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_memory_long_term_end_user_id'), 'memory_long_term', ['end_user_id'], unique=False)
|
||||
op.create_index(op.f('ix_memory_long_term_id'), 'memory_long_term', ['id'], unique=False)
|
||||
op.create_table('memory_short_term',
|
||||
sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'),
|
||||
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
|
||||
sa.Column('messages', sa.Text(), nullable=False, comment='用户消息内容'),
|
||||
sa.Column('aimessages', sa.Text(), nullable=True, comment='AI回复消息内容'),
|
||||
sa.Column('search_switch', sa.String(length=50), nullable=True, comment='搜索开关状态'),
|
||||
sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_memory_short_term_created_at'), 'memory_short_term', ['created_at'], unique=False)
|
||||
op.create_index(op.f('ix_memory_short_term_end_user_id'), 'memory_short_term', ['end_user_id'], unique=False)
|
||||
op.create_index(op.f('ix_memory_short_term_id'), 'memory_short_term', ['id'], unique=False)
|
||||
op.create_table('memory_perceptual',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('end_user_id', sa.UUID(), nullable=True),
|
||||
sa.Column('perceptual_type', sa.Integer(), nullable=False, comment='感知类型'),
|
||||
sa.Column('storage_service', sa.Integer(), nullable=True, comment='存储服务类型'),
|
||||
sa.Column('file_path', sa.String(), nullable=False, comment='文件路径'),
|
||||
sa.Column('file_name', sa.String(), nullable=False, comment='文件名称'),
|
||||
sa.Column('file_ext', sa.String(), nullable=False, comment='文件后缀名'),
|
||||
sa.Column('summary', sa.String(), nullable=True, comment='摘要'),
|
||||
sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='元信息'),
|
||||
sa.Column('created_time', sa.DateTime(), nullable=True, comment='创建时间'),
|
||||
sa.ForeignKeyConstraint(['end_user_id'], ['end_users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_memory_perceptual_end_user_id'), 'memory_perceptual', ['end_user_id'], unique=False)
|
||||
op.create_index(op.f('ix_memory_perceptual_perceptual_type'), 'memory_perceptual', ['perceptual_type'], unique=False)
|
||||
op.alter_column('multi_agent_configs', 'orchestration_mode',
|
||||
existing_type=sa.VARCHAR(length=20),
|
||||
comment='协作模式: collaboration(协作)| supervisor(监督)',
|
||||
existing_comment='协作模式: sequential|parallel|conditional|loop',
|
||||
existing_nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('multi_agent_configs', 'orchestration_mode',
|
||||
existing_type=sa.VARCHAR(length=20),
|
||||
comment='协作模式: sequential|parallel|conditional|loop',
|
||||
existing_comment='协作模式: collaboration(协作)| supervisor(监督)',
|
||||
existing_nullable=False)
|
||||
op.drop_index(op.f('ix_memory_perceptual_perceptual_type'), table_name='memory_perceptual')
|
||||
op.drop_index(op.f('ix_memory_perceptual_end_user_id'), table_name='memory_perceptual')
|
||||
op.drop_table('memory_perceptual')
|
||||
op.drop_index(op.f('ix_memory_short_term_id'), table_name='memory_short_term')
|
||||
op.drop_index(op.f('ix_memory_short_term_end_user_id'), table_name='memory_short_term')
|
||||
op.drop_index(op.f('ix_memory_short_term_created_at'), table_name='memory_short_term')
|
||||
op.drop_table('memory_short_term')
|
||||
op.drop_index(op.f('ix_memory_long_term_id'), table_name='memory_long_term')
|
||||
op.drop_index(op.f('ix_memory_long_term_end_user_id'), table_name='memory_long_term')
|
||||
op.drop_index(op.f('ix_memory_long_term_created_at'), table_name='memory_long_term')
|
||||
op.drop_table('memory_long_term')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1137,10 +1137,10 @@ export const en = {
|
||||
promptEmpty: 'Describe your use case on the left, and the orchestration preview will be displayed here.',
|
||||
|
||||
master: 'Supervisor Mode',
|
||||
master_agent: 'Supervisor Mode',
|
||||
master_agentDesc: 'Unified scheduling and management by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.',
|
||||
handoffs: 'Collaboration Mode',
|
||||
handoffsDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.',
|
||||
supervisor: 'Supervisor Mode',
|
||||
supervisorDesc: 'Unified scheduling and management by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.',
|
||||
collaboration: 'Collaboration Mode',
|
||||
collaborationDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.',
|
||||
masterConfig: 'Supervisor Configuration',
|
||||
orchestrationMode: 'Task Assignment Strategy',
|
||||
conditional: 'Intelligent Assignment',
|
||||
@@ -1150,6 +1150,8 @@ export const en = {
|
||||
merge: 'Complete Aggregation',
|
||||
vote: 'Key Information Extraction',
|
||||
priority: 'Structured Integration',
|
||||
addTool: 'Add Tool',
|
||||
tool: 'Tool',
|
||||
},
|
||||
userMemory: {
|
||||
userMemory: 'User Memory',
|
||||
@@ -1207,6 +1209,7 @@ export const en = {
|
||||
IMPLICIT_MEMORY: 'Implicit Memory',
|
||||
EMOTIONAL_MEMORY: 'Emotional Memory',
|
||||
EPISODIC_MEMORY: 'Episodic Memory',
|
||||
FORGETTING_MANAGEMENT: 'Forgetting Management',
|
||||
|
||||
endUserProfile: 'Core Profile',
|
||||
editEndUserProfile: 'Edit',
|
||||
@@ -1839,6 +1842,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
status_code: 'Status Code',
|
||||
max_attempts: 'Max Retry Attempts',
|
||||
retry_interval: 'Retry Interval',
|
||||
errorBranch: 'Error Branch',
|
||||
},
|
||||
'jinja-render': {
|
||||
template: 'Code',
|
||||
|
||||
@@ -626,10 +626,10 @@ export const zh = {
|
||||
promptEmpty: '在左侧描述您的用例,编排预览将在此处显示。',
|
||||
|
||||
master: '主管模式',
|
||||
master_agent: '主管模式',
|
||||
master_agentDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。',
|
||||
handoffs: '协作模式',
|
||||
handoffsDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。',
|
||||
supervisor: '主管模式',
|
||||
supervisorDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。',
|
||||
collaboration: '协作模式',
|
||||
collaborationDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。',
|
||||
masterConfig: '主管配置',
|
||||
orchestrationMode: '任务分配策略',
|
||||
conditional: '智能分配',
|
||||
@@ -639,6 +639,8 @@ export const zh = {
|
||||
merge: '完整汇总',
|
||||
vote: '关键信息提取',
|
||||
priority: '结构化整合',
|
||||
addTool: '添加工具',
|
||||
tool: '工具',
|
||||
},
|
||||
// 角色管理相关翻译
|
||||
role: {
|
||||
@@ -1286,6 +1288,7 @@ export const zh = {
|
||||
IMPLICIT_MEMORY: '隐性记忆',
|
||||
EMOTIONAL_MEMORY: '情绪记忆',
|
||||
EPISODIC_MEMORY: '情景记忆',
|
||||
FORGETTING_MANAGEMENT: '遗忘',
|
||||
|
||||
endUserProfile: '核心档案',
|
||||
editEndUserProfile: '编辑',
|
||||
@@ -1939,6 +1942,7 @@ export const zh = {
|
||||
status_code: '状态码',
|
||||
max_attempts: '最大重试次数',
|
||||
retry_interval: '重试间隔',
|
||||
errorBranch: '异常分支',
|
||||
},
|
||||
'jinja-render': {
|
||||
template: '代码',
|
||||
@@ -2252,5 +2256,12 @@ export const zh = {
|
||||
orderPayInfo: '支付信息',
|
||||
create_time: '创建时间',
|
||||
},
|
||||
forgetDetail: {
|
||||
title: '遗忘管理系统帮助AI智能管理记忆生命周期,通过自动识别低价值记忆、设置遗忘策略和执行定期清理,优化记忆库存储空间,提升检索效率。',
|
||||
overviewTitle: '核心指标概览',
|
||||
totalMemory: '记忆总量',
|
||||
MemoryHealth: '记忆健康度',
|
||||
riskOfForgetting: '遗忘风险',
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import type {
|
||||
MemoryConfig,
|
||||
AiPromptModalRef,
|
||||
Source,
|
||||
ToolModalRef,
|
||||
ToolOption
|
||||
} from './types'
|
||||
import type { Model } from '@/views/ModelManagement/types'
|
||||
@@ -33,7 +32,6 @@ import { memoryConfigListUrl } from '@/api/memory'
|
||||
import CustomSelect from '@/components/CustomSelect'
|
||||
import aiPrompt from '@/assets/images/application/aiPrompt.png'
|
||||
import AiPromptModal from './components/AiPromptModal'
|
||||
import ToolModal from './components/ToolModal'
|
||||
import ToolList from './components/ToolList'
|
||||
|
||||
const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => {
|
||||
@@ -115,6 +113,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
const [variableList, setVariableList] = useState<Variable[]>([])
|
||||
const [isSave, setIsSave] = useState(false)
|
||||
const initialized = useRef(false)
|
||||
const [toolList, setToolList] = useState<ToolOption[]>([])
|
||||
|
||||
// 初始化完成标记
|
||||
useEffect(() => {
|
||||
@@ -143,6 +142,11 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
if (isSave) return
|
||||
setIsSave(true)
|
||||
}, [values])
|
||||
useEffect(() => {
|
||||
if (!initialized.current) return
|
||||
if (isSave) return
|
||||
setIsSave(true)
|
||||
}, [toolList])
|
||||
|
||||
useEffect(() => {
|
||||
getModels()
|
||||
@@ -294,7 +298,11 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
...(item.config || {})
|
||||
}))
|
||||
} as KnowledgeConfig : null,
|
||||
tools: toolList
|
||||
tools: toolList.map(vo => ({
|
||||
tool_id: vo.tool_id,
|
||||
operation: vo.operation,
|
||||
enabled: vo.enabled
|
||||
}))
|
||||
}
|
||||
|
||||
console.log('params', rest, params)
|
||||
@@ -347,18 +355,6 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
form.setFieldValue('system_prompt', value)
|
||||
}
|
||||
|
||||
const toolModalRef = useRef<ToolModalRef>(null)
|
||||
const [toolList, setToolList] = useState<ToolOption[]>([])
|
||||
const handleAddTool = () => {
|
||||
toolModalRef.current?.handleOpen()
|
||||
}
|
||||
const updateTools = (tool: ToolOption) => {
|
||||
const tools = [...toolList, tool]
|
||||
setToolList(tools)
|
||||
form.setFieldValue('tools', tools)
|
||||
}
|
||||
|
||||
console.log('toolList', toolList)
|
||||
return (
|
||||
<>
|
||||
{loading && <Spin fullscreen></Spin>}
|
||||
@@ -469,10 +465,6 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
defaultModel={defaultModel}
|
||||
refresh={updatePrompt}
|
||||
/>
|
||||
<ToolModal
|
||||
ref={toolModalRef}
|
||||
refresh={updateTools}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -42,7 +42,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
|
||||
const handleSave = (flag = true) => {
|
||||
if (!data) return Promise.resolve()
|
||||
if (!values.default_model_config_id) {
|
||||
if (!values.default_model_config_id && values.orchestration_mode === 'supervisor') {
|
||||
message.warning(t('common.selectPlaceholder', { title: t('application.model') }))
|
||||
return Promise.resolve()
|
||||
}
|
||||
@@ -140,15 +140,14 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
<Space size={20} direction="vertical" style={{width: '100%'}}>
|
||||
<Card title={t('application.handoffs')}>
|
||||
<Form.Item
|
||||
name={['execution_config', 'routing_mode']}
|
||||
name="orchestration_mode"
|
||||
noStyle
|
||||
>
|
||||
<RadioGroupCard
|
||||
options={['master_agent', 'handoffs'].map((type) => ({
|
||||
options={['supervisor', 'collaboration'].map((type) => ({
|
||||
value: type,
|
||||
label: t(`application.${type}`),
|
||||
labelDesc: t(`application.${type}Desc`),
|
||||
disabled: type === 'handoffs'
|
||||
}))}
|
||||
allowClear={false}
|
||||
/>
|
||||
@@ -192,7 +191,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
))}
|
||||
</Card>
|
||||
|
||||
<Card title={t('application.masterConfig')}>
|
||||
{values?.orchestration_mode !== 'collaboration' && <Card title={t('application.masterConfig')}>
|
||||
<Form.Item
|
||||
label={t('application.model')}
|
||||
required={true}
|
||||
@@ -218,11 +217,11 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
</Row>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="orchestration_mode"
|
||||
name={['execution_config',"sub_agent_execution_mode"]}
|
||||
label={t('application.orchestrationMode')}
|
||||
>
|
||||
<Select
|
||||
options={['conditional', 'sequential', 'parallel'].map((type) => ({
|
||||
options={['sequential', 'parallel'].map((type) => ({
|
||||
value: type,
|
||||
label: t(`application.${type}`),
|
||||
}))}
|
||||
@@ -239,7 +238,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
|
||||
}))}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Card>
|
||||
</Card>}
|
||||
</Space>
|
||||
</Form>
|
||||
</Col>
|
||||
|
||||
@@ -90,11 +90,19 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
||||
switch (item.event) {
|
||||
case 'start':
|
||||
currentPromptValueRef.current = ''
|
||||
if (editorRef.current?.clear) {
|
||||
editorRef.current.clear();
|
||||
}
|
||||
break;
|
||||
case 'message':
|
||||
if (content) {
|
||||
currentPromptValueRef.current += content;
|
||||
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
|
||||
if (editorRef.current?.appendText) {
|
||||
editorRef.current.appendText(content);
|
||||
editorRef.current.scrollToBottom();
|
||||
} else {
|
||||
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
|
||||
}
|
||||
}
|
||||
if (desc) {
|
||||
setChatList(prev => {
|
||||
@@ -107,6 +115,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
|
||||
break;
|
||||
case 'end':
|
||||
setLoading(false)
|
||||
// 流结束时同步表单值
|
||||
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
|
||||
break
|
||||
}
|
||||
})
|
||||
|
||||
@@ -4,7 +4,7 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer';
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
|
||||
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
|
||||
import { $getSelection } from 'lexical';
|
||||
import { $getSelection, $getRoot, $createParagraphNode, $createTextNode, $isParagraphNode, $isTextNode } from 'lexical';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import InitialValuePlugin from './plugin/InitialValuePlugin'
|
||||
import LineBreakPlugin from './plugin/LineBreakPlugin';
|
||||
@@ -12,6 +12,9 @@ import InsertTextPlugin from './plugin/InsertTextPlugin';
|
||||
|
||||
export interface EditorRef {
|
||||
insertText: (text: string) => void;
|
||||
appendText: (text: string) => void;
|
||||
clear: () => void;
|
||||
scrollToBottom: () => void;
|
||||
}
|
||||
|
||||
interface LexicalEditorProps {
|
||||
@@ -46,6 +49,41 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
|
||||
selection.insertText(text);
|
||||
}
|
||||
});
|
||||
},
|
||||
appendText: (text: string) => {
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
const lastChild = root.getLastChild();
|
||||
if (lastChild && $isParagraphNode(lastChild)) {
|
||||
const lastTextNode = lastChild.getLastChild();
|
||||
if (lastTextNode && $isTextNode(lastTextNode)) {
|
||||
const currentText = lastTextNode.getTextContent();
|
||||
lastTextNode.setTextContent(currentText + text);
|
||||
} else {
|
||||
const textNode = $createTextNode(text);
|
||||
lastChild.append(textNode);
|
||||
}
|
||||
} else {
|
||||
const paragraph = $createParagraphNode();
|
||||
const textNode = $createTextNode(text);
|
||||
paragraph.append(textNode);
|
||||
root.append(paragraph);
|
||||
}
|
||||
});
|
||||
},
|
||||
clear: () => {
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
root.clear();
|
||||
const paragraph = $createParagraphNode();
|
||||
root.append(paragraph);
|
||||
});
|
||||
},
|
||||
scrollToBottom: () => {
|
||||
const editorElement = editor.getRootElement();
|
||||
if (editorElement) {
|
||||
editorElement.scrollTop = editorElement.scrollHeight;
|
||||
}
|
||||
}
|
||||
}), [editor]);
|
||||
|
||||
|
||||
@@ -1,21 +1,37 @@
|
||||
import { type FC, useEffect } from 'react';
|
||||
import { type FC, useEffect, useRef } from 'react';
|
||||
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
|
||||
// 设置初始值的插件
|
||||
const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const lastValueRef = useRef<string | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
if (value) {
|
||||
// 只有当value真正发生变化时才更新
|
||||
if (lastValueRef.current !== value) {
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
const currentText = root.getTextContent();
|
||||
|
||||
// 如果当前内容和新值相同,则不更新
|
||||
if (currentText === (value || '')) {
|
||||
return;
|
||||
}
|
||||
|
||||
root.clear();
|
||||
const paragraph = $createParagraphNode();
|
||||
const textNode = $createTextNode(value);
|
||||
paragraph.append(textNode);
|
||||
root.append(paragraph);
|
||||
if (value) {
|
||||
const paragraph = $createParagraphNode();
|
||||
const textNode = $createTextNode(value);
|
||||
paragraph.append(textNode);
|
||||
root.append(paragraph);
|
||||
} else {
|
||||
// 当value为undefined或空时,创建一个空段落
|
||||
const paragraph = $createParagraphNode();
|
||||
root.append(paragraph);
|
||||
}
|
||||
});
|
||||
lastValueRef.current = value;
|
||||
}
|
||||
}, [editor, value]);
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type FC, useState } from 'react';
|
||||
import { type FC, useState, useEffect } from 'react';
|
||||
import { LexicalComposer } from '@lexical/react/LexicalComposer';
|
||||
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
|
||||
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
|
||||
@@ -23,6 +23,7 @@ interface LexicalEditorProps {
|
||||
options: Suggestion[];
|
||||
variant?: 'outlined' | 'borderless';
|
||||
height?: number;
|
||||
enableJinja2?: boolean;
|
||||
}
|
||||
|
||||
const theme = {
|
||||
@@ -33,6 +34,15 @@ const theme = {
|
||||
},
|
||||
};
|
||||
|
||||
const jinja2Theme = {
|
||||
...theme,
|
||||
code: 'jinja2-expression',
|
||||
text: {
|
||||
...theme.text,
|
||||
code: 'jinja2-inline',
|
||||
},
|
||||
};
|
||||
|
||||
const Editor: FC<LexicalEditorProps> =({
|
||||
placeholder = "请输入内容...",
|
||||
value = "",
|
||||
@@ -40,19 +50,62 @@ const Editor: FC<LexicalEditorProps> =({
|
||||
options,
|
||||
variant = 'borderless',
|
||||
height = 60,
|
||||
enableJinja2 = false,
|
||||
}) => {
|
||||
|
||||
const [_count, setCount] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (enableJinja2) {
|
||||
const styleId = 'jinja2-styles';
|
||||
let existingStyle = document.getElementById(styleId);
|
||||
|
||||
if (!existingStyle) {
|
||||
const style = document.createElement('style');
|
||||
style.id = styleId;
|
||||
style.textContent = `
|
||||
.jinja2-expression {
|
||||
background-color: #f6f8fa !important;
|
||||
border: 1px solid #d1d9e0 !important;
|
||||
border-radius: 3px !important;
|
||||
padding: 2px 4px !important;
|
||||
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
|
||||
font-size: 13px !important;
|
||||
color: #0969da !important;
|
||||
}
|
||||
.jinja2-inline {
|
||||
background-color: #f6f8fa !important;
|
||||
padding: 1px 3px !important;
|
||||
border-radius: 2px !important;
|
||||
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
|
||||
font-size: 13px !important;
|
||||
color: #0969da !important;
|
||||
}
|
||||
.editor-paragraph {
|
||||
margin: 0;
|
||||
}
|
||||
.editor-paragraph:has-text('{') .editor-text,
|
||||
.editor-paragraph:has-text('[') .editor-text {
|
||||
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(style);
|
||||
}
|
||||
}
|
||||
}, [enableJinja2]);
|
||||
const initialConfig = {
|
||||
namespace: 'AutocompleteEditor',
|
||||
theme,
|
||||
nodes: [
|
||||
theme: enableJinja2 ? jinja2Theme : theme,
|
||||
nodes: enableJinja2 ? [
|
||||
// 当启用jinja2时,不使用VariableNode,使用普通文本
|
||||
] : [
|
||||
// HeadingNode,
|
||||
// QuoteNode,
|
||||
// ListItemNode,
|
||||
// ListNode,
|
||||
// LinkNode,
|
||||
// CodeNode,
|
||||
VariableNode
|
||||
VariableNode,
|
||||
],
|
||||
onError: (error: Error) => {
|
||||
console.error(error);
|
||||
@@ -96,9 +149,9 @@ const Editor: FC<LexicalEditorProps> =({
|
||||
/>
|
||||
<HistoryPlugin />
|
||||
<CommandPlugin />
|
||||
<AutocompletePlugin options={options} />
|
||||
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
|
||||
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
|
||||
<InitialValuePlugin value={value} options={options} />
|
||||
<InitialValuePlugin value={value} options={options} enableJinja2={enableJinja2} />
|
||||
</div>
|
||||
</LexicalComposer>
|
||||
);
|
||||
|
||||
@@ -17,7 +17,7 @@ export interface Suggestion {
|
||||
disabled?: boolean; // 标记是否禁用
|
||||
}
|
||||
|
||||
const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
|
||||
const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const [showSuggestions, setShowSuggestions] = useState(false);
|
||||
const [selectedIndex, setSelectedIndex] = useState(0);
|
||||
@@ -82,7 +82,32 @@ const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
|
||||
}, [editor]);
|
||||
|
||||
const insertMention = (suggestion: Suggestion) => {
|
||||
editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
|
||||
if (enableJinja2) {
|
||||
// 在jinja2模式下,插入{{variable}}格式的文本
|
||||
editor.update(() => {
|
||||
const selection = $getSelection();
|
||||
if ($isRangeSelection(selection)) {
|
||||
const anchorNode = selection.anchor.getNode();
|
||||
const anchorOffset = selection.anchor.offset;
|
||||
const nodeText = anchorNode.getTextContent();
|
||||
|
||||
// 移除触发字符'/'
|
||||
const textBefore = nodeText.substring(0, anchorOffset - 1);
|
||||
const textAfter = nodeText.substring(anchorOffset);
|
||||
const newText = textBefore + `{{${suggestion.value}}}` + textAfter;
|
||||
|
||||
anchorNode.setTextContent(newText);
|
||||
|
||||
// 设置光标位置到插入文本之后
|
||||
const newOffset = textBefore.length + `{{${suggestion.value}}}`.length;
|
||||
selection.anchor.offset = newOffset;
|
||||
selection.focus.offset = newOffset;
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 普通模式下使用VariableNode
|
||||
editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
|
||||
}
|
||||
setShowSuggestions(false);
|
||||
};
|
||||
|
||||
|
||||
@@ -8,14 +8,31 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
|
||||
interface InitialValuePluginProps {
|
||||
value: string;
|
||||
options?: Suggestion[];
|
||||
enableJinja2?: boolean;
|
||||
}
|
||||
|
||||
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [] }) => {
|
||||
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableJinja2 = false }) => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
const initializedRef = useRef(false);
|
||||
const prevValueRef = useRef<string>('');
|
||||
const isUserInputRef = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
if (!initializedRef.current && value) {
|
||||
// 监听编辑器变化,标记是否为用户输入
|
||||
const removeListener = editor.registerUpdateListener(({ editorState }) => {
|
||||
editorState.read(() => {
|
||||
const root = $getRoot();
|
||||
const textContent = root.getTextContent();
|
||||
if (textContent !== prevValueRef.current) {
|
||||
isUserInputRef.current = true;
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return removeListener;
|
||||
}, [editor]);
|
||||
|
||||
useEffect(() => {
|
||||
if (value !== prevValueRef.current && !isUserInputRef.current) {
|
||||
editor.update(() => {
|
||||
const root = $getRoot();
|
||||
root.clear();
|
||||
@@ -28,7 +45,11 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
const contextMatch = part.match(/^\{\{context\}\}$/);
|
||||
const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/);
|
||||
|
||||
// 匹配{{context}}格式
|
||||
if (enableJinja2) {
|
||||
paragraph.append($createTextNode(part));
|
||||
return;
|
||||
}
|
||||
|
||||
if (contextMatch) {
|
||||
const contextSuggestion = options.find(s => s.isContext && s.label === 'context');
|
||||
if (contextSuggestion) {
|
||||
@@ -39,7 +60,6 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
return
|
||||
}
|
||||
|
||||
// 匹配{{conv.xx}}格式
|
||||
if (conversationMatch) {
|
||||
const [_, variableName] = conversationMatch;
|
||||
const conversationSuggestion = options.find(s =>
|
||||
@@ -53,7 +73,6 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
return
|
||||
}
|
||||
|
||||
// 匹配普通变量{{nodeId.label}}格式
|
||||
if (match) {
|
||||
const [_, nodeId, label] = match;
|
||||
|
||||
@@ -75,11 +94,12 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
|
||||
});
|
||||
|
||||
root.append(paragraph);
|
||||
});
|
||||
|
||||
initializedRef.current = true;
|
||||
}, { discrete: true });
|
||||
}
|
||||
}, [options]);
|
||||
|
||||
prevValueRef.current = value;
|
||||
isUserInputRef.current = false;
|
||||
}, [value, options, editor, enableJinja2]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
import { useEffect } from 'react';
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
|
||||
import { $getRoot, $getSelection, $isRangeSelection, TextNode, $createTextNode } from 'lexical';
|
||||
|
||||
const JsonHighlightPlugin = () => {
|
||||
const [editor] = useLexicalComposerContext();
|
||||
|
||||
useEffect(() => {
|
||||
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
|
||||
const text = textNode.getTextContent();
|
||||
|
||||
// Check if text contains JSON-like patterns
|
||||
if (containsJsonPatterns(text)) {
|
||||
const parent = textNode.getParent();
|
||||
if (!parent) return;
|
||||
|
||||
// Split text into tokens and create new nodes with appropriate classes
|
||||
const tokens = tokenizeJson(text);
|
||||
const newNodes = tokens.map(token => {
|
||||
const newNode = $createTextNode(token.text);
|
||||
|
||||
// Set format based on token type
|
||||
switch (token.type) {
|
||||
case 'string':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #032f62');
|
||||
break;
|
||||
case 'number':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #005cc5');
|
||||
break;
|
||||
case 'boolean':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #d73a49');
|
||||
break;
|
||||
case 'null':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #6f42c1');
|
||||
break;
|
||||
case 'key':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #22863a; font-weight: bold');
|
||||
break;
|
||||
case 'punctuation':
|
||||
newNode.setFormat('code');
|
||||
newNode.setStyle('color: #24292e');
|
||||
break;
|
||||
}
|
||||
|
||||
return newNode;
|
||||
});
|
||||
|
||||
// Replace the original text node with the new highlighted nodes
|
||||
if (newNodes.length > 1) {
|
||||
textNode.replace(newNodes[0]);
|
||||
for (let i = 1; i < newNodes.length; i++) {
|
||||
newNodes[i - 1].insertAfter(newNodes[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}, [editor]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
function containsJsonPatterns(text: string): boolean {
|
||||
// Check for JSON-like patterns
|
||||
return /[{}\[\]:,]/.test(text) ||
|
||||
/"[^"]*"/.test(text) ||
|
||||
/\b\d+(\.\d+)?\b/.test(text) ||
|
||||
/\b(true|false|null)\b/.test(text);
|
||||
}
|
||||
|
||||
function tokenizeJson(text: string): Array<{text: string, type: string}> {
|
||||
const tokens: Array<{text: string, type: string}> = [];
|
||||
const regex = /("[^"]*")|([{}\[\]:,])|(\b\d+(?:\.\d+)?\b)|(\b(?:true|false|null)\b)|(\s+)|([^\s{}\[\]:,"]+)/g;
|
||||
|
||||
let match;
|
||||
while ((match = regex.exec(text)) !== null) {
|
||||
const [fullMatch, string, punctuation, number, boolean, whitespace, other] = match;
|
||||
|
||||
if (string) {
|
||||
// Check if it's a key (followed by colon)
|
||||
const afterMatch = text.slice(match.index + fullMatch.length).trim();
|
||||
if (afterMatch.startsWith(':')) {
|
||||
tokens.push({ text: fullMatch, type: 'key' });
|
||||
} else {
|
||||
tokens.push({ text: fullMatch, type: 'string' });
|
||||
}
|
||||
} else if (punctuation) {
|
||||
tokens.push({ text: fullMatch, type: 'punctuation' });
|
||||
} else if (number) {
|
||||
tokens.push({ text: fullMatch, type: 'number' });
|
||||
} else if (boolean) {
|
||||
if (fullMatch === 'null') {
|
||||
tokens.push({ text: fullMatch, type: 'null' });
|
||||
} else {
|
||||
tokens.push({ text: fullMatch, type: 'boolean' });
|
||||
}
|
||||
} else if (whitespace || other) {
|
||||
tokens.push({ text: fullMatch, type: 'text' });
|
||||
}
|
||||
}
|
||||
|
||||
return tokens;
|
||||
}
|
||||
|
||||
export default JsonHighlightPlugin;
|
||||
@@ -28,12 +28,18 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
|
||||
|
||||
const handleOpen = (data?: HttpRequestConfigForm['auth']) => {
|
||||
if (data) {
|
||||
form.setFieldsValue({
|
||||
const initialValues = {
|
||||
auth: !data.auth_type || data.auth_type === 'none' ? 'none' : 'api_key',
|
||||
auth_type: !data.auth_type || data.auth_type === 'none' ? undefined : data.auth_type,
|
||||
header: data.header,
|
||||
api_key: data.api_key
|
||||
})
|
||||
}
|
||||
form.setFieldValue('auth', initialValues.auth)
|
||||
if (initialValues.auth !== 'none') {
|
||||
setTimeout(() => {
|
||||
form.setFieldsValue(initialValues)
|
||||
}, 1)
|
||||
}
|
||||
}
|
||||
setVisible(true);
|
||||
};
|
||||
@@ -91,6 +97,9 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
|
||||
<FormItem
|
||||
name="auth"
|
||||
label={t('workflow.config.http-request.authType')}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseSelect') }
|
||||
]}
|
||||
>
|
||||
<Select
|
||||
options={[
|
||||
@@ -103,6 +112,9 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
|
||||
<FormItem
|
||||
name="auth_type"
|
||||
label={t('workflow.config.http-request.authType')}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseSelect') }
|
||||
]}
|
||||
>
|
||||
<Select
|
||||
options={[
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useMemo, useCallback } from 'react';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Button, Select, Table, Form, type TableProps } from 'antd';
|
||||
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons';
|
||||
@@ -6,46 +6,8 @@ import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin';
|
||||
import Empty from '@/components/Empty';
|
||||
import VariableSelect from '../VariableSelect';
|
||||
|
||||
interface EditableCellProps extends React.HTMLAttributes<HTMLElement> {
|
||||
name?: string | string[];
|
||||
inputType?: 'select' | 'variableSelect';
|
||||
options?: { value: string, label: string }[] | Suggestion[];
|
||||
}
|
||||
|
||||
const EditableCell: React.FC<React.PropsWithChildren<EditableCellProps>> = ({
|
||||
name,
|
||||
inputType,
|
||||
options,
|
||||
children,
|
||||
...restProps
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!inputType) return <td {...restProps}>{children}</td>;
|
||||
|
||||
return (
|
||||
<td {...restProps}>
|
||||
<Form.Item name={name} style={{ margin: 0 }}>
|
||||
{inputType === 'select' ? (
|
||||
<Select
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
size="small"
|
||||
options={options as { value: string, label: string }[]}
|
||||
/>
|
||||
) : (
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
size="small"
|
||||
options={(options as Suggestion[]) || []}
|
||||
/>
|
||||
)}
|
||||
</Form.Item>
|
||||
</td>
|
||||
);
|
||||
};
|
||||
|
||||
export interface TableRow {
|
||||
key: string;
|
||||
key?: string;
|
||||
name?: string;
|
||||
value?: string;
|
||||
type?: string;
|
||||
@@ -56,100 +18,158 @@ interface EditableTableProps {
|
||||
title?: string;
|
||||
options?: Suggestion[];
|
||||
typeOptions?: { value: string, label: string }[]
|
||||
filterBooleanType?: boolean;
|
||||
}
|
||||
|
||||
const EditableTable: React.FC<EditableTableProps> = ({
|
||||
parentName,
|
||||
title,
|
||||
options = [],
|
||||
typeOptions = []
|
||||
typeOptions = [],
|
||||
filterBooleanType = false
|
||||
}) => {
|
||||
const { t } = useTranslation();
|
||||
const form = Form.useFormInstance();
|
||||
const values = Form.useWatch(typeof parentName === 'string' ? [parentName] : parentName, form);
|
||||
|
||||
const createNewRow = (): TableRow => ({
|
||||
key: Date.now().toString(),
|
||||
name: undefined,
|
||||
value: undefined,
|
||||
...(typeOptions.length > 0 && { type: typeOptions[0].value })
|
||||
});
|
||||
|
||||
const handleAdd = useCallback(() => {
|
||||
form.setFieldValue(parentName, [...(values ?? []), createNewRow()]);
|
||||
}, [form, parentName, values, typeOptions]);
|
||||
|
||||
const handleDelete = useCallback((index: number) => {
|
||||
const currentValues = form.getFieldValue(parentName) || [];
|
||||
form.setFieldValue(parentName, currentValues.filter((_: TableRow, i: number) => i !== index));
|
||||
}, [form, parentName]);
|
||||
|
||||
const createColumn = (dataIndex: string, inputType: 'select' | 'variableSelect', width: string, columnOptions: any[]) => ({
|
||||
title: t(`workflow.config.${dataIndex}`),
|
||||
dataIndex,
|
||||
width,
|
||||
onCell: (_: TableRow, index?: number) => ({
|
||||
name: typeof parentName === 'string' ? [parentName, index ?? 0, dataIndex] : [...parentName, index ?? 0, dataIndex],
|
||||
inputType,
|
||||
options: columnOptions
|
||||
} as any)
|
||||
});
|
||||
|
||||
const columns: TableProps<TableRow>['columns'] = useMemo(() => {
|
||||
const getColumns = (remove: (index: number) => void): TableProps<TableRow>['columns'] => {
|
||||
const hasType = typeOptions.length > 0;
|
||||
const baseWidth = hasType ? '35%' : '45%';
|
||||
|
||||
return [
|
||||
createColumn('name', 'variableSelect', baseWidth, options),
|
||||
...(hasType ? [createColumn('type', 'select', '20%', typeOptions)] : []),
|
||||
createColumn('value', 'variableSelect', baseWidth, options),
|
||||
{
|
||||
title: t('workflow.config.name'),
|
||||
dataIndex: 'name',
|
||||
width: baseWidth,
|
||||
render: (_: any, __: TableRow, index: number) => (
|
||||
<Form.Item name={[index, 'name']} noStyle>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
size="small"
|
||||
options={options}
|
||||
filterBooleanType={filterBooleanType}
|
||||
popupMatchSelectWidth={false}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
},
|
||||
...(hasType ? [{
|
||||
title: t('workflow.config.type'),
|
||||
dataIndex: 'type',
|
||||
width: '20%',
|
||||
render: (_: any, __: TableRow, index: number) => (
|
||||
<Form.Item shouldUpdate noStyle>
|
||||
{(form) => (
|
||||
<Form.Item name={[index, 'type']} noStyle>
|
||||
<Select
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
size="small"
|
||||
options={typeOptions}
|
||||
popupMatchSelectWidth={false}
|
||||
onChange={() => {
|
||||
form.setFieldValue([...Array.isArray(parentName) ? parentName : [parentName], index, 'value'], undefined);
|
||||
}}
|
||||
/>
|
||||
</Form.Item>
|
||||
)}
|
||||
</Form.Item>
|
||||
)
|
||||
}] : []),
|
||||
{
|
||||
title: t('workflow.config.value'),
|
||||
dataIndex: 'value',
|
||||
width: baseWidth,
|
||||
render: (_: any, __: TableRow, index: number) => (
|
||||
<Form.Item
|
||||
shouldUpdate={(prevValues, currentValues) => {
|
||||
const prevType = prevValues?.[Array.isArray(parentName) ? parentName.join('.') : parentName]?.[index]?.type;
|
||||
const currentType = currentValues?.[Array.isArray(parentName) ? parentName.join('.') : parentName]?.[index]?.type;
|
||||
return prevType !== currentType;
|
||||
}}
|
||||
noStyle
|
||||
>
|
||||
{(form) => {
|
||||
const currentType = form.getFieldValue([...Array.isArray(parentName) ? parentName : [parentName], index, 'type']);
|
||||
const filteredOptions = currentType === 'file'
|
||||
? options.filter(option => option.dataType === 'file')
|
||||
: options;
|
||||
|
||||
return (
|
||||
<Form.Item name={[index, 'value']} noStyle>
|
||||
<VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
size="small"
|
||||
options={filteredOptions}
|
||||
filterBooleanType={filterBooleanType}
|
||||
popupMatchSelectWidth={false}
|
||||
/>
|
||||
</Form.Item>
|
||||
);
|
||||
}}
|
||||
</Form.Item>
|
||||
)
|
||||
},
|
||||
{
|
||||
title: '',
|
||||
dataIndex: 'actions',
|
||||
width: '10%',
|
||||
render: (_: any, __: TableRow, index: number) => (
|
||||
<Button type="text" icon={<DeleteOutlined />} onClick={() => handleDelete(index)} />
|
||||
<Button type="text" icon={<DeleteOutlined />} onClick={() => remove(index)} />
|
||||
)
|
||||
}
|
||||
];
|
||||
}, [typeOptions, options, t, parentName, handleDelete]);
|
||||
|
||||
const AddButton = ({ block = false }: { block?: boolean }) => (
|
||||
<Button
|
||||
type={block ? "dashed" : "text"}
|
||||
icon={<PlusOutlined />}
|
||||
onClick={handleAdd}
|
||||
size="small"
|
||||
block={block}
|
||||
className={block ? "rb:mt-1" : ""}
|
||||
>
|
||||
{block && `+${t('common.add')}`}
|
||||
</Button>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rb:mb-4">
|
||||
{title && (
|
||||
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
|
||||
<div className="rb:font-medium">{title}</div>
|
||||
<AddButton />
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Form.Item name={parentName}>
|
||||
<Table<TableRow>
|
||||
components={{ body: { cell: EditableCell } }}
|
||||
bordered
|
||||
dataSource={values}
|
||||
columns={columns}
|
||||
pagination={false}
|
||||
size="small"
|
||||
locale={{ emptyText: <Empty size={88} /> }}
|
||||
scroll={{ x: 'max-content' }}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
{!title && <AddButton block />}
|
||||
<Form.List name={parentName}>
|
||||
{(fields, { add, remove }) => {
|
||||
const AddButton = ({ block = false }: { block?: boolean }) => (
|
||||
<Button
|
||||
type={block ? "dashed" : "text"}
|
||||
icon={<PlusOutlined />}
|
||||
onClick={() => add(createNewRow())}
|
||||
size="small"
|
||||
block={block}
|
||||
className={block ? "rb:mt-1" : ""}
|
||||
>
|
||||
{block && `+${t('common.add')}`}
|
||||
</Button>
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
{title && (
|
||||
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
|
||||
<div className="rb:font-medium">{title}</div>
|
||||
<AddButton />
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Table<TableRow>
|
||||
bordered
|
||||
dataSource={fields.map((field) => ({
|
||||
key: String(field.key),
|
||||
name: undefined,
|
||||
value: undefined,
|
||||
type: undefined
|
||||
}))}
|
||||
columns={getColumns(remove)}
|
||||
pagination={false}
|
||||
size="small"
|
||||
locale={{ emptyText: <Empty size={88} /> }}
|
||||
scroll={{ x: 'max-content' }}
|
||||
/>
|
||||
|
||||
{!title && <AddButton block />}
|
||||
</>
|
||||
);
|
||||
}}
|
||||
</Form.List>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -9,8 +9,10 @@ import VariableSelect from "../VariableSelect";
|
||||
import MessageEditor from '../MessageEditor'
|
||||
import EditableTable from './EditableTable'
|
||||
|
||||
const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: any; }> = ({
|
||||
options,
|
||||
selectedNode,
|
||||
graphRef
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const form = Form.useFormInstance();
|
||||
@@ -22,18 +24,45 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
}
|
||||
const handleRefresh = (auth: HttpRequestConfigForm['auth']) => {
|
||||
console.log('handleRefresh', auth)
|
||||
form.setFieldsValue({ auth: {...auth} })
|
||||
form.setFieldsValue({ auth })
|
||||
}
|
||||
|
||||
const handleChangeBodyContentType = (contentType: string) => {
|
||||
const currentValues = form.getFieldsValue()
|
||||
const handleChangeBodyContentType = () => {
|
||||
form.setFieldValue(['body', 'data'], undefined)
|
||||
}
|
||||
|
||||
const handleChangeErrorHandleMethod = (method: string) => {
|
||||
form.setFieldsValue({
|
||||
body: {
|
||||
...currentValues?.body,
|
||||
content_type: contentType,
|
||||
data: undefined
|
||||
error_handle: {
|
||||
method,
|
||||
body: undefined,
|
||||
status_code: undefined,
|
||||
headers: undefined
|
||||
}
|
||||
})
|
||||
|
||||
// 更新节点连接桩
|
||||
console.log('handleChangeErrorHandleMethod', selectedNode, graphRef?.current)
|
||||
if (selectedNode && graphRef?.current) {
|
||||
const existingPorts = selectedNode.getPorts();
|
||||
const errorPort = existingPorts.find((port: any) => port.id === 'ERROR');
|
||||
|
||||
if (method === 'branch' && !errorPort) {
|
||||
// 添加异常节点连接桩
|
||||
selectedNode.addPort({
|
||||
id: 'ERROR',
|
||||
group: 'right',
|
||||
attrs: { text: { text: t('workflow.config.http-request.errorBranch'), fontSize: 12, fill: '#5B6167' }}
|
||||
});
|
||||
} else if (method !== 'branch' && errorPort) {
|
||||
// 移除异常节点连接桩和相关连线
|
||||
const edges = graphRef.current.getEdges().filter((edge: any) =>
|
||||
edge.getSourceCellId() === selectedNode.id && edge.getSourcePortId() === 'ERROR'
|
||||
);
|
||||
edges.forEach((edge: any) => graphRef.current.removeCell(edge));
|
||||
selectedNode.removePort('ERROR');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
console.log('HttpRequest', values)
|
||||
@@ -73,6 +102,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
parentName="headers"
|
||||
title="HEADERS"
|
||||
options={options}
|
||||
filterBooleanType={true}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
@@ -81,6 +111,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
parentName="params"
|
||||
title="PARAMS"
|
||||
options={options}
|
||||
filterBooleanType={true}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
@@ -104,6 +135,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
<EditableTable
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
filterBooleanType={true}
|
||||
typeOptions={[
|
||||
{ label: 'text', value: 'text' },
|
||||
{ label: 'file', value: 'file' }
|
||||
@@ -116,12 +148,15 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
<EditableTable
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
filterBooleanType={true}
|
||||
/>
|
||||
</Form.Item>
|
||||
}
|
||||
{values?.body?.content_type === 'json' &&
|
||||
<Form.Item name={['body', 'data']}>
|
||||
<MessageEditor
|
||||
key="json"
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
isArray={false}
|
||||
title="JSON"
|
||||
@@ -131,6 +166,8 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
{values?.body?.content_type === 'raw' &&
|
||||
<Form.Item name={['body', 'data']}>
|
||||
<MessageEditor
|
||||
key="raw"
|
||||
parentName={['body', 'data']}
|
||||
options={options}
|
||||
isArray={false}
|
||||
title="RAW TEXT"
|
||||
@@ -141,6 +178,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
<Form.Item name={['body', 'data']}>
|
||||
<VariableSelect
|
||||
options={options}
|
||||
filterBooleanType={true}
|
||||
/>
|
||||
</Form.Item>
|
||||
}
|
||||
@@ -185,7 +223,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name={['retry', 'retry_interval']}
|
||||
label={t('workflow.config.http-request.retry_interval')}
|
||||
label={<>{t('workflow.config.http-request.retry_interval')} <span className="rb:text-[#5B6167]">(ms)</span></>}
|
||||
>
|
||||
<InputNumber placeholder={t('common.pleaseEnter')} className="rb:w-full!" />
|
||||
</Form.Item>
|
||||
@@ -196,6 +234,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
<Form.Item layout="horizontal" name={['error_handle', 'method']} label={t('workflow.config.http-request.error_handle')}>
|
||||
<Select
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
onChange={handleChangeErrorHandleMethod}
|
||||
options={[
|
||||
{ value: 'none', label: t('workflow.config.http-request.none') },
|
||||
{ value: 'default', label: t('workflow.config.http-request.default') },
|
||||
@@ -207,32 +246,19 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
|
||||
<>
|
||||
<Form.Item
|
||||
name={['error_handle', 'body']}
|
||||
label="body"
|
||||
label={<>body <span className="rb:text-[#5B6167] rb:ml-1">string</span></>}
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name={['error_handle', 'status_code']}
|
||||
label="status_code"
|
||||
label={<>status_code <span className="rb:text-[#5B6167] rb:ml-1">number</span></>}
|
||||
>
|
||||
<InputNumber placeholder={t('common.pleaseEnter')} className="rb:w-full!" />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name={['error_handle', 'headers']}
|
||||
label="headers"
|
||||
rules={[
|
||||
{
|
||||
validator: (_, value) => {
|
||||
if (!value) return Promise.resolve();
|
||||
try {
|
||||
JSON.parse(value);
|
||||
return Promise.resolve();
|
||||
} catch {
|
||||
return Promise.reject(new Error('Please enter valid JSON format'));
|
||||
}
|
||||
}
|
||||
}
|
||||
]}
|
||||
label={<>headers <span className="rb:text-[#5B6167] rb:ml-1">object</span></>}
|
||||
>
|
||||
<Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||
</Form.Item>
|
||||
|
||||
@@ -17,14 +17,14 @@ const MappingList: React.FC<MappingListProps> = ({ name, options }) => {
|
||||
{(fields, { add, remove }) => (
|
||||
<>
|
||||
{fields.map(({ key, name, ...restField }) => (
|
||||
<Row gutter={12} className="rb:mb-2">
|
||||
<Row key={key} gutter={12} className="rb:mb-2">
|
||||
<Col span={10}>
|
||||
<Form.Item
|
||||
{...restField}
|
||||
name={[name, 'name']}
|
||||
noStyle
|
||||
>
|
||||
<Input placeholder={t('common.pleaseEnter')} />
|
||||
<Input placeholder={t('common.pleaseEnter')} data-field-type="mapping-name" />
|
||||
</Form.Item>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
|
||||
@@ -5,14 +5,15 @@ import { MinusCircleOutlined } from '@ant-design/icons';
|
||||
import Editor from '../Editor'
|
||||
import type { Suggestion } from '../Editor/plugin/AutocompletePlugin'
|
||||
|
||||
interface TextareaProps {
|
||||
interface MessageEditor {
|
||||
options: Suggestion[];
|
||||
title?: string
|
||||
isArray?: boolean;
|
||||
parentName?: string;
|
||||
parentName?: string | string[];
|
||||
label?: string;
|
||||
placeholder?: string;
|
||||
value?: string;
|
||||
enableJinja2?: boolean;
|
||||
onChange?: (value?: string) => void;
|
||||
}
|
||||
const roleOptions = [
|
||||
@@ -20,12 +21,13 @@ const roleOptions = [
|
||||
{ label: 'USER', value: 'USER' },
|
||||
{ label: 'ASSISTANT', value: 'ASSISTANT' },
|
||||
]
|
||||
const MessageEditor: FC<TextareaProps> = ({
|
||||
const MessageEditor: FC<MessageEditor> = ({
|
||||
title,
|
||||
isArray = true,
|
||||
parentName = 'messages',
|
||||
placeholder,
|
||||
options,
|
||||
enableJinja2 = false,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const form = Form.useFormInstance();
|
||||
@@ -33,10 +35,17 @@ const MessageEditor: FC<TextareaProps> = ({
|
||||
|
||||
// 检查是否已经使用了context变量,将已使用的context设置为disabled
|
||||
const processedOptions = useMemo(() => {
|
||||
if (!isArray || !values?.[parentName]) return options;
|
||||
if (!isArray) return options;
|
||||
|
||||
// 获取表单中对应字段的值
|
||||
const fieldValue = Array.isArray(parentName)
|
||||
? parentName.reduce((obj, key) => obj?.[key], values)
|
||||
: values?.[parentName];
|
||||
|
||||
if (!fieldValue) return options;
|
||||
|
||||
// 获取所有消息内容
|
||||
const allContents = values[parentName]
|
||||
const allContents = fieldValue
|
||||
.map((msg: any) => msg?.content || '')
|
||||
.join(' ');
|
||||
|
||||
@@ -50,7 +59,11 @@ const MessageEditor: FC<TextareaProps> = ({
|
||||
}, [options, values, parentName, isArray]);
|
||||
|
||||
const handleAdd = (add: FormListOperation['add']) => {
|
||||
const list = values?.[parentName] || [];
|
||||
const fieldValue = Array.isArray(parentName)
|
||||
? parentName.reduce((obj, key) => obj?.[key], values)
|
||||
: values?.[parentName];
|
||||
|
||||
const list = fieldValue || [];
|
||||
const lastRole = list.length > 0 ? list[list.length - 1]?.role : 'ASSISTANT';
|
||||
|
||||
add({
|
||||
@@ -61,14 +74,14 @@ const MessageEditor: FC<TextareaProps> = ({
|
||||
|
||||
if (!isArray) {
|
||||
return (
|
||||
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
||||
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white" data-editor-type={parentName === 'template' ? 'template' : undefined}>
|
||||
<Row>
|
||||
<Col span={12}>
|
||||
{title ?? t('workflow.answerDesc')}
|
||||
</Col>
|
||||
</Row>
|
||||
<Form.Item name={parentName} noStyle>
|
||||
<Editor placeholder={placeholder} options={processedOptions} />
|
||||
<Editor enableJinja2={enableJinja2} placeholder={placeholder} options={processedOptions} />
|
||||
</Form.Item>
|
||||
</Space>
|
||||
);
|
||||
@@ -79,7 +92,11 @@ const MessageEditor: FC<TextareaProps> = ({
|
||||
{(fields, { add, remove }) => (
|
||||
<Space size={12} direction="vertical" className="rb:w-full">
|
||||
{fields.map(({ key, name, ...restField }) => {
|
||||
const currentRole = (values?.[parentName]?.[name]?.role || 'USER').toUpperCase();
|
||||
const fieldValue = Array.isArray(parentName)
|
||||
? parentName.reduce((obj, key) => obj?.[key], values)
|
||||
: values?.[parentName];
|
||||
|
||||
const currentRole = (fieldValue?.[name]?.role || 'USER').toUpperCase();
|
||||
|
||||
return (
|
||||
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
|
||||
@@ -105,7 +122,7 @@ const MessageEditor: FC<TextareaProps> = ({
|
||||
)}
|
||||
</Row>
|
||||
<Form.Item {...restField} name={[name, 'content']} noStyle>
|
||||
<Editor placeholder={placeholder} options={processedOptions} />
|
||||
<Editor enableJinja2={enableJinja2} placeholder={placeholder} options={processedOptions} />
|
||||
</Form.Item>
|
||||
</Space>
|
||||
);
|
||||
|
||||
@@ -9,6 +9,7 @@ interface VariableSelectProps extends SelectProps {
|
||||
value?: string;
|
||||
onChange?: (value: string) => void;
|
||||
allowClear?: boolean;
|
||||
filterBooleanType?: boolean;
|
||||
}
|
||||
|
||||
const VariableSelect: FC<VariableSelectProps> = ({
|
||||
@@ -18,6 +19,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
|
||||
allowClear = true,
|
||||
onChange,
|
||||
size,
|
||||
filterBooleanType = false,
|
||||
...resetPorps
|
||||
}) => {
|
||||
|
||||
@@ -26,7 +28,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
|
||||
}
|
||||
const labelRender: LabelRender = (props) => {
|
||||
const { value } = props
|
||||
const filterOption = options.find(vo => `{{${vo.value}}}` === value)
|
||||
const filterOption = filteredOptions.find(vo => `{{${vo.value}}}` === value)
|
||||
|
||||
if (filterOption) {
|
||||
return (
|
||||
@@ -54,7 +56,11 @@ const VariableSelect: FC<VariableSelectProps> = ({
|
||||
}
|
||||
return null
|
||||
}
|
||||
const groupedSuggestions = options.reduce((groups: Record<string, any[]>, suggestion) => {
|
||||
const filteredOptions = filterBooleanType
|
||||
? options.filter(option => option.dataType !== 'boolean')
|
||||
: options;
|
||||
|
||||
const groupedSuggestions = filteredOptions.reduce((groups: Record<string, any[]>, suggestion) => {
|
||||
const { nodeData } = suggestion
|
||||
const nodeId = nodeData.id as string;
|
||||
if (!groups[nodeId]) {
|
||||
@@ -64,12 +70,10 @@ const VariableSelect: FC<VariableSelectProps> = ({
|
||||
return groups;
|
||||
}, {});
|
||||
|
||||
const groupedOptions = Object.entries(groupedSuggestions).map(([nodeId, suggestions]) => ({
|
||||
const groupedOptions = Object.entries(groupedSuggestions).map(([_nodeId, suggestions]) => ({
|
||||
label: suggestions[0].nodeData.name,
|
||||
options: suggestions.map(s => ({ label: s.label, value: `{{${s.value}}}` }))
|
||||
}));
|
||||
|
||||
console.log('groupedOptions', groupedOptions)
|
||||
|
||||
return (
|
||||
<Select
|
||||
|
||||
@@ -45,13 +45,134 @@ const Properties: FC<PropertiesProps> = ({
|
||||
const values = Form.useWatch([], form);
|
||||
const variableModalRef = useRef<VariableEditModalRef>(null)
|
||||
const [editIndex, setEditIndex] = useState<number | null>(null)
|
||||
const prevMappingNamesRef = useRef<string[]>([])
|
||||
const prevTemplateVarsRef = useRef<string[]>([])
|
||||
const syncTimeoutRef = useRef<NodeJS.Timeout | null>(null)
|
||||
const isSyncingRef = useRef(false)
|
||||
const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedNode?.getData()?.id) {
|
||||
form.resetFields()
|
||||
prevMappingNamesRef.current = []
|
||||
prevTemplateVarsRef.current = []
|
||||
lastSyncSourceRef.current = null
|
||||
}
|
||||
}, [selectedNode?.getData()?.id])
|
||||
|
||||
// Sync template when mapping names change
|
||||
useEffect(() => {
|
||||
if (isSyncingRef.current || lastSyncSourceRef.current === 'mapping' || selectedNode?.data?.type !== 'jinja-render' || !values?.mapping || !values?.template) return
|
||||
|
||||
const currentMappingNames = values.mapping.map((item: any) => item.name).filter(Boolean)
|
||||
const prevNames = prevMappingNamesRef.current
|
||||
|
||||
if (prevNames.length === 0) {
|
||||
prevMappingNamesRef.current = currentMappingNames
|
||||
return
|
||||
}
|
||||
|
||||
if (JSON.stringify(prevNames) === JSON.stringify(currentMappingNames)) return
|
||||
|
||||
if (syncTimeoutRef.current) clearTimeout(syncTimeoutRef.current)
|
||||
const activeElement = document.activeElement as HTMLElement
|
||||
|
||||
syncTimeoutRef.current = setTimeout(() => {
|
||||
let updatedTemplate = String(form.getFieldValue('template') || '')
|
||||
|
||||
prevNames.forEach((oldName, index) => {
|
||||
const newName = currentMappingNames[index]
|
||||
if (newName && oldName !== newName) {
|
||||
updatedTemplate = updatedTemplate.replace(new RegExp(`{{\\s*${oldName}\\s*}}`, 'g'), `{{${newName}}}`)
|
||||
}
|
||||
})
|
||||
|
||||
if (updatedTemplate !== form.getFieldValue('template')) {
|
||||
isSyncingRef.current = true
|
||||
lastSyncSourceRef.current = 'mapping'
|
||||
const newTemplateVars = (updatedTemplate.match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, ''))
|
||||
prevTemplateVarsRef.current = newTemplateVars
|
||||
prevMappingNamesRef.current = currentMappingNames
|
||||
form.setFieldValue('template', updatedTemplate)
|
||||
|
||||
requestAnimationFrame(() => {
|
||||
activeElement?.focus?.()
|
||||
setTimeout(() => {
|
||||
isSyncingRef.current = false
|
||||
lastSyncSourceRef.current = null
|
||||
}, 50)
|
||||
})
|
||||
} else {
|
||||
prevMappingNamesRef.current = currentMappingNames
|
||||
}
|
||||
}, 0)
|
||||
}, [values?.mapping, selectedNode?.data?.type, form])
|
||||
|
||||
// Sync mapping when template variables change
|
||||
useEffect(() => {
|
||||
if (isSyncingRef.current || lastSyncSourceRef.current === 'template' || selectedNode?.data?.type !== 'jinja-render' || !values?.template || !values?.mapping) return
|
||||
|
||||
const templateVars = (String(values.template).match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, ''))
|
||||
if (JSON.stringify(prevTemplateVarsRef.current) === JSON.stringify(templateVars)) return
|
||||
|
||||
const isTemplateEditor = document.activeElement?.closest('[data-editor-type="template"]')
|
||||
if (!isTemplateEditor) {
|
||||
prevTemplateVarsRef.current = templateVars
|
||||
return
|
||||
}
|
||||
|
||||
const updatedMapping = [...values.mapping]
|
||||
const existingNames = updatedMapping.map(item => item.name)
|
||||
let updatedTemplate = String(values.template)
|
||||
|
||||
if (prevTemplateVarsRef.current.length > 0) {
|
||||
prevTemplateVarsRef.current.forEach((oldVar, index) => {
|
||||
const newVar = templateVars[index]
|
||||
if (newVar && oldVar !== newVar && updatedMapping[index]) {
|
||||
updatedMapping[index] = { ...updatedMapping[index], name: newVar }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
templateVars.forEach(varName => {
|
||||
const existingMapping = updatedMapping.find(item => item.value === `{{${varName}}}`)
|
||||
const regex = new RegExp(`{{\\s*${varName.replace(/\./g, '\\.')}\\s*}}`, 'g')
|
||||
|
||||
if (existingMapping) {
|
||||
updatedTemplate = updatedTemplate.replace(regex, `{{${existingMapping.name}}}`)
|
||||
} else if (!existingNames.includes(varName)) {
|
||||
const mappingName = varName.includes('.') ? varName.split('.').pop() || varName : varName
|
||||
updatedMapping.push({ name: mappingName, value: `{{${varName}}}` })
|
||||
updatedTemplate = updatedTemplate.replace(regex, `{{${mappingName}}}`)
|
||||
}
|
||||
})
|
||||
|
||||
const seenNames = new Set<string>()
|
||||
const finalMapping = updatedMapping.filter(item => {
|
||||
const isUsed = templateVars.some(v => item.name === v || item.value === `{{${v}}}`)
|
||||
if (!isUsed || seenNames.has(item.name)) return false
|
||||
seenNames.add(item.name)
|
||||
return true
|
||||
})
|
||||
|
||||
isSyncingRef.current = true
|
||||
lastSyncSourceRef.current = 'template'
|
||||
prevMappingNamesRef.current = finalMapping.map((item: any) => item.name).filter(Boolean)
|
||||
prevTemplateVarsRef.current = templateVars
|
||||
|
||||
if (JSON.stringify(finalMapping) !== JSON.stringify(values.mapping)) {
|
||||
form.setFieldValue('mapping', finalMapping)
|
||||
}
|
||||
if (updatedTemplate !== String(values.template)) {
|
||||
form.setFieldValue('template', updatedTemplate)
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
isSyncingRef.current = false
|
||||
lastSyncSourceRef.current = null
|
||||
}, 50)
|
||||
}, [values?.template, selectedNode?.data?.type, form])
|
||||
|
||||
useEffect(() => {
|
||||
if (selectedNode && form) {
|
||||
const { type = 'default', name = '', config } = selectedNode.getData() || {}
|
||||
@@ -96,6 +217,8 @@ const Properties: FC<PropertiesProps> = ({
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
|
||||
Object.keys(values).forEach(key => {
|
||||
if (selectedNode.data?.config?.[key]) {
|
||||
// Create a deep copy to avoid reference sharing between nodes
|
||||
@@ -114,7 +237,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
...allRest,
|
||||
})
|
||||
}
|
||||
}, [values, selectedNode])
|
||||
}, [values, selectedNode, form])
|
||||
|
||||
const handleAddVariable = () => {
|
||||
variableModalRef.current?.handleOpen()
|
||||
@@ -190,11 +313,88 @@ const Properties: FC<PropertiesProps> = ({
|
||||
.map(node => node.id);
|
||||
};
|
||||
|
||||
// Find parent loop/iteration node if current node is a child
|
||||
const getParentLoopNode = (nodeId: string): Node | null => {
|
||||
const node = nodes.find(n => n.id === nodeId);
|
||||
if (!node) return null;
|
||||
|
||||
const nodeData = node.getData();
|
||||
const cycle = nodeData?.cycle;
|
||||
|
||||
if (cycle) {
|
||||
const parentNode = nodes.find(n => n.getData().id === cycle);
|
||||
if (parentNode) {
|
||||
const parentData = parentNode.getData();
|
||||
if (parentData?.type === 'loop' || parentData?.type === 'iteration') {
|
||||
return parentNode;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id);
|
||||
const childNodeIds = getChildNodes(selectedNode.id);
|
||||
const parentLoopNode = getParentLoopNode(selectedNode.id);
|
||||
|
||||
console.log('childNodeIds', selectedNode, childNodeIds)
|
||||
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
|
||||
|
||||
// Add parent loop/iteration node variables if current node is a child
|
||||
if (parentLoopNode) {
|
||||
const parentData = parentLoopNode.getData();
|
||||
|
||||
if (parentData.type === 'loop') {
|
||||
const cycleVars = parentData.cycle_vars || [];
|
||||
cycleVars.forEach((cycleVar: any) => {
|
||||
const key = `${parentLoopNode.getData().id}_cycle_${cycleVar.name}`;
|
||||
if (!addedKeys.has(key)) {
|
||||
addedKeys.add(key);
|
||||
variableList.push({
|
||||
key,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'String',
|
||||
value: `${parentLoopNode.getData().id}.${cycleVar.name}`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
});
|
||||
} else if (parentData.type === 'iteration') {
|
||||
// Add item and index variables for iteration parent
|
||||
const itemKey = `${parentLoopNode.getData().id}_item`;
|
||||
const indexKey = `${parentLoopNode.getData().id}_index`;
|
||||
|
||||
if (!addedKeys.has(itemKey)) {
|
||||
addedKeys.add(itemKey);
|
||||
variableList.push({
|
||||
key: itemKey,
|
||||
label: 'item',
|
||||
type: 'variable',
|
||||
dataType: 'Object',
|
||||
value: `${parentLoopNode.getData().id}.item`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
|
||||
if (!addedKeys.has(indexKey)) {
|
||||
addedKeys.add(indexKey);
|
||||
variableList.push({
|
||||
key: indexKey,
|
||||
label: 'index',
|
||||
type: 'variable',
|
||||
dataType: 'Number',
|
||||
value: `${parentLoopNode.getData().id}.index`,
|
||||
nodeData: parentData,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add variables from nodes preceding the parent loop/iteration node
|
||||
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
|
||||
allRelevantNodeIds.push(...parentPreviousNodeIds);
|
||||
}
|
||||
|
||||
allRelevantNodeIds.forEach(nodeId => {
|
||||
const node = nodes.find(n => n.id === nodeId);
|
||||
if (!node) return;
|
||||
@@ -363,6 +563,87 @@ const Properties: FC<PropertiesProps> = ({
|
||||
});
|
||||
}
|
||||
break
|
||||
case 'question-classifier':
|
||||
const classNameKey = `${nodeId}_class_name`;
|
||||
const outputKey = `${nodeId}_output`;
|
||||
if (!addedKeys.has(classNameKey)) {
|
||||
addedKeys.add(classNameKey);
|
||||
variableList.push({
|
||||
key: classNameKey,
|
||||
label: 'class_name',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${node.getData().id}.class_name`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
if (!addedKeys.has(outputKey)) {
|
||||
addedKeys.add(outputKey);
|
||||
variableList.push({
|
||||
key: outputKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: 'string',
|
||||
value: `${node.getData().id}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
case 'iteration':
|
||||
const iterationOutputKey = `${nodeId}_output`;
|
||||
if (!addedKeys.has(iterationOutputKey)) {
|
||||
addedKeys.add(iterationOutputKey);
|
||||
// Get the data type from the output configuration, default to string
|
||||
const outputConfig = nodeData.output;
|
||||
let outputDataType = 'string';
|
||||
if (outputConfig) {
|
||||
// Find the selected variable from variableList to get its type
|
||||
const selectedVariable = variableList.find(v => v.value === outputConfig);
|
||||
if (selectedVariable) {
|
||||
outputDataType = selectedVariable.dataType;
|
||||
}
|
||||
}
|
||||
variableList.push({
|
||||
key: iterationOutputKey,
|
||||
label: 'output',
|
||||
type: 'variable',
|
||||
dataType: outputDataType,
|
||||
value: `${node.getData().id}.output`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
case 'loop':
|
||||
const cycleVars = nodeData.cycle_vars || [];
|
||||
cycleVars.forEach((cycleVar: any) => {
|
||||
const cycleVarKey = `${nodeId}_cycle_${cycleVar.name}`;
|
||||
if (!addedKeys.has(cycleVarKey)) {
|
||||
addedKeys.add(cycleVarKey);
|
||||
variableList.push({
|
||||
key: cycleVarKey,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'string',
|
||||
value: `${node.getData().id}.${cycleVar.name}`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
});
|
||||
break
|
||||
case 'tool':
|
||||
const toolDataKey = `${nodeId}_data`;
|
||||
if (!addedKeys.has(toolDataKey)) {
|
||||
addedKeys.add(toolDataKey);
|
||||
variableList.push({
|
||||
key: toolDataKey,
|
||||
label: 'data',
|
||||
type: 'variable',
|
||||
dataType: 'object',
|
||||
value: `${node.getData().id}.data`,
|
||||
nodeData: nodeData,
|
||||
});
|
||||
}
|
||||
break
|
||||
}
|
||||
});
|
||||
|
||||
@@ -388,6 +669,14 @@ const Properties: FC<PropertiesProps> = ({
|
||||
return variableList;
|
||||
}, [selectedNode, graphRef]);
|
||||
|
||||
// Filter out boolean type variables for loop and llm nodes
|
||||
const getFilteredVariableList = (nodeType?: string) => {
|
||||
if (nodeType === 'loop' || nodeType === 'llm') {
|
||||
return variableList.filter(variable => variable.dataType !== 'boolean');
|
||||
}
|
||||
return variableList;
|
||||
};
|
||||
|
||||
console.log('values', values)
|
||||
console.log('variableList', variableList, selectedNode?.data)
|
||||
|
||||
@@ -412,6 +701,8 @@ const Properties: FC<PropertiesProps> = ({
|
||||
{selectedNode?.data?.type === 'http-request'
|
||||
? <HttpRequest
|
||||
options={variableList}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
/>
|
||||
: selectedNode?.data?.type === 'tool'
|
||||
? <ToolConfig options={variableList} />
|
||||
@@ -469,7 +760,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
|
||||
if (selectedNode?.data?.type === 'llm' && key === 'messages' && config.type === 'define') {
|
||||
// 为llm节点且isArray=true时添加context变量支持
|
||||
let contextVariableList = [...variableList];
|
||||
let contextVariableList = [...getFilteredVariableList('llm')];
|
||||
const isArrayMode = config.isArray !== false; // 默认为true
|
||||
|
||||
if (isArrayMode) {
|
||||
@@ -491,14 +782,14 @@ const Properties: FC<PropertiesProps> = ({
|
||||
|
||||
return (
|
||||
<Form.Item key={key} name={key}>
|
||||
<MessageEditor options={contextVariableList} parentName={key} />
|
||||
<MessageEditor key={key} options={contextVariableList} parentName={key} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
if (selectedNode?.data?.type === 'end' && key === 'output') {
|
||||
return (
|
||||
<Form.Item key={key} name={key}>
|
||||
<MessageEditor isArray={false} parentName={key} options={variableList} />
|
||||
<MessageEditor key={key} isArray={false} parentName={key} options={variableList} />
|
||||
</Form.Item>
|
||||
)
|
||||
}
|
||||
@@ -525,7 +816,8 @@ const Properties: FC<PropertiesProps> = ({
|
||||
title={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
|
||||
isArray={!!config.isArray}
|
||||
parentName={key}
|
||||
options={variableList}
|
||||
enableJinja2={config.enableJinja2 as boolean}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
@@ -546,7 +838,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<GroupVariableList
|
||||
name={key}
|
||||
options={variableList}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
isCanAdd={!!(values as any)?.group}
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -558,7 +850,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<CaseList
|
||||
name={key}
|
||||
options={variableList}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
/>
|
||||
@@ -571,7 +863,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}
|
||||
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
|
||||
>
|
||||
<MappingList name={key} options={variableList} />
|
||||
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type)} />
|
||||
</Form.Item>
|
||||
|
||||
)
|
||||
@@ -581,7 +873,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
<Form.Item key={key} name={key}>
|
||||
<CycleVarsList
|
||||
parentName={key}
|
||||
options={variableList}
|
||||
options={getFilteredVariableList(selectedNode?.data?.type)}
|
||||
/>
|
||||
</Form.Item>
|
||||
)
|
||||
@@ -655,9 +947,9 @@ const Properties: FC<PropertiesProps> = ({
|
||||
findParentLoopIteration(selectedNode.id);
|
||||
}
|
||||
|
||||
return [...variableList, ...loopIterationVars];
|
||||
return [...getFilteredVariableList(selectedNode?.data?.type), ...loopIterationVars];
|
||||
}
|
||||
return variableList;
|
||||
return getFilteredVariableList(selectedNode?.data?.type);
|
||||
})()
|
||||
}
|
||||
/>
|
||||
@@ -678,7 +970,7 @@ const Properties: FC<PropertiesProps> = ({
|
||||
? <Input.TextArea placeholder={t('common.pleaseEnter')} />
|
||||
: config.type === 'select'
|
||||
? <Select
|
||||
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
|
||||
options={config.needTranslation ? (config.options || []).map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
/>
|
||||
: config.type === 'inputNumber'
|
||||
@@ -698,9 +990,10 @@ const Properties: FC<PropertiesProps> = ({
|
||||
? <VariableSelect
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
options={(() => {
|
||||
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type);
|
||||
// Apply filtering if specified in config
|
||||
if (config.filterNodeTypes || config.filterVariableNames) {
|
||||
return variableList.filter(variable => {
|
||||
return baseVariableList.filter(variable => {
|
||||
const nodeTypeMatch = !config.filterNodeTypes ||
|
||||
(Array.isArray(config.filterNodeTypes) && config.filterNodeTypes.includes(variable.nodeData?.type));
|
||||
const variableNameMatch = !config.filterVariableNames ||
|
||||
@@ -721,22 +1014,38 @@ const Properties: FC<PropertiesProps> = ({
|
||||
return nodeData?.cycle === selectedNode.id;
|
||||
});
|
||||
|
||||
return variableList.filter(variable =>
|
||||
return baseVariableList.filter(variable =>
|
||||
childNodes.some(node => node.id === variable.nodeData?.id)
|
||||
);
|
||||
}
|
||||
return variableList;
|
||||
return baseVariableList;
|
||||
})()
|
||||
}
|
||||
/>
|
||||
: config.type === 'switch'
|
||||
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
|
||||
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
|
||||
: config.type === 'categoryList'
|
||||
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
|
||||
: config.type === 'conditionList'
|
||||
? <ConditionList
|
||||
parentName={key}
|
||||
options={variableList}
|
||||
options={(() => {
|
||||
// For loop nodes, add cycle_vars to condition options
|
||||
if (selectedNode?.data?.type === 'loop') {
|
||||
const cycleVars = values?.cycle_vars || [];
|
||||
const cycleVarSuggestions: Suggestion[] = cycleVars.map((cycleVar: any) => ({
|
||||
key: `${selectedNode.id}_cycle_${cycleVar.name}`,
|
||||
label: cycleVar.name,
|
||||
type: 'variable',
|
||||
dataType: cycleVar.type || 'String',
|
||||
value: `${selectedNode.getData().id}.${cycleVar.name}`,
|
||||
nodeData: selectedNode.getData(),
|
||||
}));
|
||||
return [...getFilteredVariableList(selectedNode?.data?.type), ...cycleVarSuggestions];
|
||||
}
|
||||
return getFilteredVariableList(selectedNode?.data?.type);
|
||||
})()
|
||||
}
|
||||
selectedNode={selectedNode}
|
||||
graphRef={graphRef}
|
||||
addBtnText={t('workflow.config.addCase')}
|
||||
|
||||
@@ -300,6 +300,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config: {
|
||||
cycle_vars: {
|
||||
type: 'cycleVarsList',
|
||||
defaultValue: []
|
||||
},
|
||||
condition: {
|
||||
type: 'conditionList',
|
||||
@@ -395,12 +396,14 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
},
|
||||
retry: {
|
||||
type: 'switch',
|
||||
defaultValue: false
|
||||
defaultValue: {
|
||||
enable: false
|
||||
}
|
||||
},
|
||||
error_handle: {
|
||||
type: 'define',
|
||||
defaultValue: {
|
||||
method: 'default'
|
||||
method: 'none'
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -420,11 +423,13 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config: {
|
||||
mapping: {
|
||||
type: 'mappingList',
|
||||
defaultValue: []
|
||||
defaultValue: [{name: 'arg1'}]
|
||||
},
|
||||
template: {
|
||||
type: 'messageEditor',
|
||||
isArray: false,
|
||||
enableJinja2: true,
|
||||
defaultValue: "{{arg1}}"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +193,27 @@ export const useWorkflowGraph = ({
|
||||
nodeConfig.height = newHeight;
|
||||
}
|
||||
|
||||
// 如果是http-request节点,检查error_handle.method配置
|
||||
if (type === 'http-request' && (config as any).error_handle?.method === 'branch') {
|
||||
const portAttrs = {
|
||||
circle: {
|
||||
r: 4, magnet: true, stroke: '#155EEF', strokeWidth: 2, fill: '#155EEF', position: { top: 22 }
|
||||
},
|
||||
};
|
||||
|
||||
nodeConfig.ports = {
|
||||
groups: {
|
||||
right: { position: 'right', attrs: portAttrs },
|
||||
left: { position: 'left', attrs: portAttrs },
|
||||
},
|
||||
items: [
|
||||
{ group: 'left' },
|
||||
{ group: 'right', id: 'right' },
|
||||
{ group: 'right', id: 'ERROR', attrs: { text: { text: t('workflow.config.http-request.errorBranch'), fontSize: 12, fill: '#5B6167' }}}
|
||||
]
|
||||
};
|
||||
}
|
||||
|
||||
return nodeConfig
|
||||
})
|
||||
|
||||
@@ -284,6 +305,14 @@ export const useWorkflowGraph = ({
|
||||
}
|
||||
}
|
||||
|
||||
// 如果是http-request节点且有label,根据label匹配对应的端口
|
||||
if (sourceCell.getData()?.type === 'http-request' && label) {
|
||||
const matchingPort = sourcePorts.find((port: any) => port.id === label);
|
||||
if (matchingPort) {
|
||||
sourcePort = label;
|
||||
}
|
||||
}
|
||||
|
||||
const edgeConfig = {
|
||||
source: {
|
||||
cell: sourceCell.id,
|
||||
@@ -954,6 +983,23 @@ export const useWorkflowGraph = ({
|
||||
};
|
||||
}
|
||||
|
||||
// 如果是http-request节点的右侧端口连线,添加label
|
||||
if (sourceCell?.getData()?.type === 'http-request') {
|
||||
if (sourcePortId === 'ERROR') {
|
||||
return {
|
||||
source: sourceCell.getData().id,
|
||||
target: targetCell?.getData().id,
|
||||
label: 'ERROR',
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
source: sourceCell.getData().id,
|
||||
target: targetCell?.getData().id,
|
||||
label: 'SUCCESS',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
source: sourceCell?.getData().id,
|
||||
target: targetCell?.getData().id,
|
||||
|
||||
@@ -26,6 +26,7 @@ export interface NodeConfig {
|
||||
|
||||
group_variables?: Array<{ key: string, value: string[] }>
|
||||
cycle?: string;
|
||||
cycle_vars?: Array<{ name: string; type: string; value: string; input_type: string; }>
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user