Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-07 18:19:27 +08:00
45 changed files with 4488 additions and 259 deletions

View File

@@ -4,40 +4,44 @@
认证方式: JWT Token
"""
from fastapi import APIRouter
from . import (
model_controller,
task_controller,
test_controller,
user_controller,
auth_controller,
workspace_controller,
setup_controller,
file_controller,
document_controller,
knowledge_controller,
chunk_controller,
knowledgeshare_controller,
api_key_controller,
app_controller,
upload_controller,
auth_controller,
chunk_controller,
document_controller,
emotion_config_controller,
emotion_controller,
file_controller,
home_page_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_storage_controller,
memory_dashboard_controller,
memory_forget_controller,
memory_reflection_controller,
memory_short_term_controller,
api_key_controller,
release_share_controller,
public_share_controller,
memory_storage_controller,
model_controller,
multi_agent_controller,
workflow_controller,
emotion_controller,
emotion_config_controller,
prompt_optimizer_controller,
public_share_controller,
release_share_controller,
setup_controller,
task_controller,
test_controller,
tool_controller,
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
workspace_controller,
memory_forget_controller,
home_page_controller,
memory_perceptual_controller
)
from . import user_memory_controllers
# 创建管理端 API 路由器
manager_router = APIRouter()
@@ -76,5 +80,7 @@ manager_router.include_router(memory_short_term_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(memory_forget_controller.router)
manager_router.include_router(home_page_controller.router)
manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,302 @@
from datetime import datetime
from typing import Optional
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import (
cur_workspace_access_guard,
get_current_user,
)
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.implicit_memory_service import ImplicitMemoryService
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/implicit-memory",
tags=["Implicit Memory"],
)
def handle_implicit_memory_error(e: Exception, operation: str, user_id: str = None) -> dict:
"""
Centralized error handling for implicit memory operations.
Args:
e: The exception that occurred
operation: Description of the operation that failed
user_id: Optional user ID for logging context
Returns:
Standardized error response
"""
error_context = f"user_id={user_id}" if user_id else "unknown user"
if isinstance(e, ValueError):
if "user" in str(e).lower() and "not found" in str(e).lower():
api_logger.warning(f"Invalid user ID for {operation}: {error_context}")
return fail(BizCode.INVALID_USER_ID, "无效的用户ID", str(e))
elif "insufficient" in str(e).lower() or "no data" in str(e).lower():
api_logger.warning(f"Insufficient data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "数据不足,无法进行分析", str(e))
else:
api_logger.warning(f"Invalid parameters for {operation}: {error_context}")
return fail(BizCode.INVALID_FILTER_PARAMS, "无效的参数", str(e))
elif isinstance(e, KeyError):
api_logger.warning(f"Missing required data for {operation}: {error_context}")
return fail(BizCode.INSUFFICIENT_DATA, "缺少必要的数据", str(e))
elif isinstance(e, (ConnectionError, TimeoutError)):
api_logger.error(f"Service unavailable for {operation}: {error_context}")
return fail(BizCode.SERVICE_UNAVAILABLE, "服务暂时不可用", str(e))
elif "analysis" in str(e).lower() or "llm" in str(e).lower():
api_logger.error(f"Analysis failed for {operation}: {error_context}", exc_info=True)
return fail(BizCode.ANALYSIS_FAILED, "分析处理失败", str(e))
elif "storage" in str(e).lower() or "database" in str(e).lower():
api_logger.error(f"Storage error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.PROFILE_STORAGE_ERROR, "数据存储失败", str(e))
else:
api_logger.error(f"Unexpected error for {operation}: {error_context}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, f"{operation}失败", str(e))
def validate_user_id(user_id: str) -> None:
"""
Validate user ID format and constraints.
Args:
user_id: User ID to validate
Raises:
ValueError: If user ID is invalid
"""
if not user_id or not user_id.strip():
raise ValueError("User ID cannot be empty")
if len(user_id.strip()) < 1:
raise ValueError("User ID is too short")
def validate_date_range(start_date: Optional[datetime], end_date: Optional[datetime]) -> None:
"""
Validate date range parameters.
Args:
start_date: Start date
end_date: End date
Raises:
ValueError: If date range is invalid
"""
if (start_date and not end_date) or (end_date and not start_date):
raise ValueError("Both start_date and end_date must be provided together")
if start_date and end_date and start_date >= end_date:
raise ValueError("start_date must be before end_date")
if start_date and start_date > datetime.now():
raise ValueError("start_date cannot be in the future")
def validate_confidence_threshold(threshold: float) -> None:
"""
Validate confidence threshold parameter.
Args:
threshold: Confidence threshold to validate
Raises:
ValueError: If threshold is invalid
"""
if not 0.0 <= threshold <= 1.0:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
end_date: Optional[datetime] = Query(None, description="Filter end date"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user preference tags with filtering options.
Args:
user_id: Target user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List of preference tags matching the filters
"""
api_logger.info(f"Preference tags requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
validate_confidence_threshold(confidence_threshold)
validate_date_range(start_date, end_date)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Build date range
date_range = None
if start_date and end_date:
from app.schemas.implicit_memory_schema import DateRange
date_range = DateRange(start_date=start_date, end_date=end_date)
# Get preference tags
tags = await service.get_preference_tags(
user_id=user_id,
confidence_threshold=confidence_threshold,
tag_category=tag_category,
date_range=date_range
)
api_logger.info(f"Retrieved {len(tags)} preference tags for user: {user_id}")
return success(data=[tag.dict() for tag in tags], msg="偏好标签获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's four-dimension personality portrait.
Args:
user_id: Target user ID
include_history: Whether to include historical trend data
Returns:
Four-dimension personality portrait with scores and evidence
"""
api_logger.info(f"Dimension portrait requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
portrait = await service.get_dimension_portrait(
user_id=user_id,
include_history=include_history
)
api_logger.info(f"Dimension portrait retrieved for user: {user_id}")
return success(data=portrait.dict(), msg="四维画像获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", user_id)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's interest area distribution across four areas.
Args:
user_id: Target user ID
include_trends: Whether to include trend analysis data
Returns:
Interest area distribution with percentages and evidence
"""
api_logger.info(f"Interest area distribution requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
distribution = await service.get_interest_area_distribution(
user_id=user_id,
include_trends=include_trends
)
api_logger.info(f"Interest area distribution retrieved for user: {user_id}")
return success(data=distribution.dict(), msg="兴趣领域分布获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
Get user's behavioral habits with filtering options.
Args:
user_id: Target user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
Returns:
List of behavioral habits matching the filters
"""
api_logger.info(f"Behavior habits requested for user: {user_id}")
try:
# Validate inputs
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=user_id)
habits = await service.get_behavior_habits(
user_id=user_id,
confidence_level=confidence_level,
frequency_pattern=frequency_pattern,
time_period=time_period
)
api_logger.info(f"Retrieved {len(habits)} behavior habits for user: {user_id}")
return success(data=[habit.dict() for habit in habits], msg="行为习惯获取成功")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", user_id)

View File

@@ -82,6 +82,13 @@ class BizCode(IntEnum):
MEMORY_WRITE_FAILED = 9501
MEMORY_READ_FAILED = 9502
MEMORY_CONFIG_NOT_FOUND = 9503
# Implicit Memory API96xx
INVALID_USER_ID = 9601
INSUFFICIENT_DATA = 9602
INVALID_FILTER_PARAMS = 9603
ANALYSIS_FAILED = 9604
PROFILE_STORAGE_ERROR = 9605
# 系统100xx
INTERNAL_ERROR = 10001
@@ -159,6 +166,13 @@ HTTP_MAPPING = {
BizCode.MEMORY_READ_FAILED: 500,
BizCode.MEMORY_CONFIG_NOT_FOUND: 400,
# Implicit Memory API 错误码映射
BizCode.INVALID_USER_ID: 400,
BizCode.INSUFFICIENT_DATA: 400,
BizCode.INVALID_FILTER_PARAMS: 400,
BizCode.ANALYSIS_FAILED: 500,
BizCode.PROFILE_STORAGE_ERROR: 500,
BizCode.INTERNAL_ERROR: 500,
BizCode.DB_ERROR: 500,
BizCode.SERVICE_UNAVAILABLE: 503,

View File

@@ -0,0 +1,6 @@
"""Implicit Memory Module
This module provides behavior analysis capabilities that build comprehensive user profiles
by analyzing memory summary nodes from Neo4j. It creates detailed user portraits across
multiple dimensions, tracks interest distributions, and identifies behavioral habits.
"""

View File

@@ -0,0 +1 @@
"""Analyzers package for implicit memory analysis components."""

View File

@@ -0,0 +1,264 @@
"""Dimension Analyzer for Implicit Memory System
This module implements LLM-based personality dimension analysis from user memory summaries.
It analyzes four key dimensions: creativity, aesthetic, technology, and literature,
providing percentage scores with evidence and reasoning.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
ConfidenceLevel,
DimensionPortrait,
DimensionScore,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class DimensionData(BaseModel):
"""Individual dimension analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
reasoning: str = ""
confidence_level: str = "medium"
class DimensionAnalysisResponse(BaseModel):
"""Response model for dimension analysis."""
dimensions: Dict[str, DimensionData] = Field(default_factory=dict)
class DimensionAnalyzer:
"""Analyzes user memory summaries to extract personality dimensions."""
# Define the four dimensions we analyze
DIMENSIONS = ["creativity", "aesthetic", "technology", "literature"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the dimension analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_dimensions(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_portrait: Optional[DimensionPortrait] = None
) -> DimensionPortrait:
"""Analyze user summaries to extract personality dimensions.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_portrait: Optional existing portrait for incremental updates
Returns:
Dimension portrait with four personality dimensions
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_portrait(user_id)
try:
logger.info(f"Analyzing dimensions for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_dimensions(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create dimension scores
dimension_scores = {}
current_time = datetime.now()
for dimension_name in self.DIMENSIONS:
# Handle response as dictionary
dimensions_data = response.get("dimensions", {})
dimension_data = dimensions_data.get(dimension_name)
if dimension_data:
# Validate and create dimension score
score = self._create_dimension_score(
dimension_name=dimension_name,
dimension_data=dimension_data
)
dimension_scores[dimension_name] = score
else:
# Create default score if missing
logger.warning(f"Missing dimension data for {dimension_name}, using default")
dimension_scores[dimension_name] = self._create_default_dimension_score(dimension_name)
# Create dimension portrait
portrait = DimensionPortrait(
user_id=user_id,
creativity=dimension_scores["creativity"],
aesthetic=dimension_scores["aesthetic"],
technology=dimension_scores["technology"],
literature=dimension_scores["literature"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries),
historical_trends=self._calculate_historical_trends(existing_portrait) if existing_portrait else None
)
logger.info(f"Created dimension portrait for user {user_id}")
return portrait
except LLMClientException:
raise
except Exception as e:
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Dimension analysis failed: {e}") from e
def _create_dimension_score(
self,
dimension_name: str,
dimension_data: dict
) -> DimensionScore:
"""Create a dimension score from analysis data.
Args:
dimension_name: Name of the dimension
dimension_data: Analysis data dictionary for the dimension
Returns:
DimensionScore object
"""
# Validate percentage - handle dict access
percentage = dimension_data.get("percentage", 0.0)
percentage = max(0.0, min(100.0, float(percentage)))
# Validate confidence level
confidence_level_str = dimension_data.get("confidence_level", "low")
confidence_level = self._validate_confidence_level(confidence_level_str)
# Ensure evidence is not empty
evidence = dimension_data.get("evidence", [])
if not evidence:
evidence = ["No specific evidence found"]
# Ensure reasoning is not empty
reasoning = dimension_data.get("reasoning", "")
if not reasoning:
reasoning = f"Analysis for {dimension_name} dimension"
return DimensionScore(
dimension_name=dimension_name,
percentage=percentage,
evidence=evidence,
reasoning=reasoning,
confidence_level=confidence_level
)
def _create_default_dimension_score(self, dimension_name: str) -> DimensionScore:
"""Create a default dimension score when analysis fails.
Args:
dimension_name: Name of the dimension
Returns:
Default DimensionScore object
"""
return DimensionScore(
dimension_name=dimension_name,
percentage=0.0,
evidence=["Insufficient data for analysis"],
reasoning=f"No clear evidence found for {dimension_name} dimension",
confidence_level=ConfidenceLevel.LOW
)
def _validate_confidence_level(self, confidence_str: str) -> ConfidenceLevel:
"""Validate and convert confidence level string.
Args:
confidence_str: Confidence level as string
Returns:
ConfidenceLevel enum value
"""
if not confidence_str:
return ConfidenceLevel.MEDIUM
confidence_str = str(confidence_str).lower().strip()
if confidence_str in ["high", "높음"]:
return ConfidenceLevel.HIGH
elif confidence_str in ["medium", "중간"]:
return ConfidenceLevel.MEDIUM
elif confidence_str in ["low", "낮음"]:
return ConfidenceLevel.LOW
else:
logger.warning(f"Unknown confidence level: {confidence_str}, defaulting to medium")
return ConfidenceLevel.MEDIUM
def _create_empty_portrait(self, user_id: str) -> DimensionPortrait:
"""Create an empty dimension portrait when no data is available.
Args:
user_id: Target user ID
Returns:
Empty DimensionPortrait
"""
current_time = datetime.now()
return DimensionPortrait(
user_id=user_id,
creativity=self._create_default_dimension_score("creativity"),
aesthetic=self._create_default_dimension_score("aesthetic"),
technology=self._create_default_dimension_score("technology"),
literature=self._create_default_dimension_score("literature"),
analysis_timestamp=current_time,
total_summaries_analyzed=0,
historical_trends=None
)
def _calculate_historical_trends(
self,
existing_portrait: DimensionPortrait
) -> List[Dict[str, Any]]:
"""Calculate historical trends from existing portrait.
Args:
existing_portrait: Previous dimension portrait
Returns:
List of historical trend data
"""
if not existing_portrait:
return []
# Create trend entry from existing portrait
trend_entry = {
"timestamp": existing_portrait.analysis_timestamp.isoformat(),
"creativity": existing_portrait.creativity.percentage,
"aesthetic": existing_portrait.aesthetic.percentage,
"technology": existing_portrait.technology.percentage,
"literature": existing_portrait.literature.percentage,
"total_summaries": existing_portrait.total_summaries_analyzed
}
# Combine with existing trends
existing_trends = existing_portrait.historical_trends or []
# Keep only recent trends (last 10 entries)
all_trends = existing_trends + [trend_entry]
return all_trends[-10:]

View File

@@ -0,0 +1,452 @@
"""Habit Analyzer for Implicit Memory System
This module implements LLM-based behavioral habit analysis from user memory summaries.
It identifies recurring behavioral patterns, temporal patterns, and consolidates
similar habits with confidence scoring.
"""
import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
ConfidenceLevel,
FrequencyPattern,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class HabitData(BaseModel):
"""Individual habit analysis data."""
habit_description: str
frequency_pattern: str
time_context: str
confidence_level: str
supporting_summaries: List[str] = Field(default_factory=list)
specific_examples: List[str] = Field(default_factory=list)
is_current: bool = True
class HabitAnalysisResponse(BaseModel):
"""Response model for habit analysis."""
habits: List[HabitData] = Field(default_factory=list)
class HabitAnalyzer:
"""Analyzes user memory summaries to extract behavioral habits."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the habit analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_habits(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_habits: Optional[List[BehaviorHabit]] = None
) -> List[BehaviorHabit]:
"""Analyze user summaries to extract behavioral habits.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_habits: Optional existing habits for consolidation
Returns:
List of extracted behavioral habits
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return existing_habits or []
try:
logger.info(f"Analyzing habits for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_habits(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to BehaviorHabit objects
behavior_habits = []
current_time = datetime.now()
for habit_data in response.get("habits", []):
try:
# Handle habit_data as dictionary
supporting_summaries = habit_data.get("supporting_summaries", [])
specific_examples = habit_data.get("specific_examples", [])
# Determine observation dates from summaries
first_observed, last_observed = self._determine_observation_dates(
user_summaries, supporting_summaries
)
behavior_habit = BehaviorHabit(
habit_description=habit_data.get("habit_description", ""),
frequency_pattern=self._validate_frequency_pattern(habit_data.get("frequency_pattern", "occasional")),
time_context=habit_data.get("time_context", ""),
confidence_level=self._validate_confidence_level(habit_data.get("confidence_level", "medium")),
supporting_summaries=supporting_summaries,
specific_examples=specific_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=habit_data.get("is_current", True)
)
# Validate habit
if self._is_valid_habit(behavior_habit):
behavior_habits.append(behavior_habit)
else:
logger.warning(f"Invalid habit skipped: {behavior_habit.habit_description}")
except Exception as e:
logger.error(f"Error creating behavior habit: {e}")
continue
# Consolidate with existing habits if provided
if existing_habits:
behavior_habits = self._consolidate_habits(
new_habits=behavior_habits,
existing_habits=existing_habits
)
# Sort habits by confidence and recency
behavior_habits = self._sort_habits_by_priority(behavior_habits)
logger.info(f"Extracted {len(behavior_habits)} habits for user {user_id}")
return behavior_habits
except LLMClientException:
raise
except Exception as e:
logger.error(f"Habit analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Habit analysis failed: {e}") from e
def _validate_frequency_pattern(self, frequency_str: str) -> FrequencyPattern:
"""Validate and convert frequency pattern string.
Args:
frequency_str: Frequency pattern as string
Returns:
FrequencyPattern enum value
"""
frequency_str = frequency_str.lower().strip()
frequency_mapping = {
"daily": FrequencyPattern.DAILY,
"weekly": FrequencyPattern.WEEKLY,
"monthly": FrequencyPattern.MONTHLY,
"seasonal": FrequencyPattern.SEASONAL,
"occasional": FrequencyPattern.OCCASIONAL,
"event_triggered": FrequencyPattern.EVENT_TRIGGERED,
"event-triggered": FrequencyPattern.EVENT_TRIGGERED,
}
return frequency_mapping.get(frequency_str, FrequencyPattern.OCCASIONAL)
def _validate_confidence_level(self, confidence_str: str) -> ConfidenceLevel:
"""Validate and convert confidence level string.
Args:
confidence_str: Confidence level as string
Returns:
ConfidenceLevel enum value
"""
confidence_str = confidence_str.lower().strip()
if confidence_str in ["high", "높음"]:
return ConfidenceLevel.HIGH
elif confidence_str in ["medium", "중간"]:
return ConfidenceLevel.MEDIUM
elif confidence_str in ["low", "낮음"]:
return ConfidenceLevel.LOW
else:
logger.warning(f"Unknown confidence level: {confidence_str}, defaulting to medium")
return ConfidenceLevel.MEDIUM
def _determine_observation_dates(
self,
user_summaries: List[UserMemorySummary],
supporting_summary_ids: List[str]
) -> tuple[datetime, datetime]:
"""Determine first and last observation dates for a habit.
Args:
user_summaries: List of user summaries
supporting_summary_ids: IDs of summaries supporting the habit
Returns:
Tuple of (first_observed, last_observed) dates
"""
from datetime import timezone
# Find summaries that support this habit
supporting_summaries = [
summary for summary in user_summaries
if summary.summary_id in supporting_summary_ids
]
if not supporting_summaries:
# Use all summaries if no specific supporting summaries found
supporting_summaries = user_summaries
if not supporting_summaries:
current_time = datetime.now(timezone.utc).replace(tzinfo=None)
return current_time, current_time
# Get date range from supporting summaries - normalize to naive datetimes
timestamps = []
for summary in supporting_summaries:
ts = summary.timestamp
# Convert to naive datetime if it's timezone-aware
if ts.tzinfo is not None:
ts = ts.replace(tzinfo=None)
timestamps.append(ts)
first_observed = min(timestamps)
last_observed = max(timestamps)
return first_observed, last_observed
def _is_valid_habit(self, habit: BehaviorHabit) -> bool:
"""Validate a behavioral habit.
Args:
habit: Behavioral habit to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not habit.habit_description or not habit.habit_description.strip():
return False
# Check time context
if not habit.time_context or not habit.time_context.strip():
return False
# Check supporting summaries
if not habit.supporting_summaries or len(habit.supporting_summaries) == 0:
return False
# Check specific examples
if not habit.specific_examples or len(habit.specific_examples) == 0:
return False
# Check observation dates
if habit.first_observed > habit.last_observed:
return False
return True
except Exception as e:
logger.error(f"Error validating habit: {e}")
return False
def _consolidate_habits(
self,
new_habits: List[BehaviorHabit],
existing_habits: List[BehaviorHabit],
similarity_threshold: float = 0.7
) -> List[BehaviorHabit]:
"""Consolidate new habits with existing ones.
Args:
new_habits: Newly extracted habits
existing_habits: Existing habits
similarity_threshold: Threshold for considering habits similar
Returns:
Consolidated list of habits
"""
consolidated = existing_habits.copy()
current_time = datetime.now()
for new_habit in new_habits:
# Find similar existing habit
similar_habit = self._find_similar_habit(
new_habit, existing_habits, similarity_threshold
)
if similar_habit:
# Update existing habit
updated_habit = self._merge_habits(similar_habit, new_habit, current_time)
# Replace in consolidated list
for i, habit in enumerate(consolidated):
if habit.habit_description == similar_habit.habit_description:
consolidated[i] = updated_habit
break
else:
# Add as new habit
consolidated.append(new_habit)
return consolidated
def _find_similar_habit(
self,
target_habit: BehaviorHabit,
existing_habits: List[BehaviorHabit],
threshold: float
) -> Optional[BehaviorHabit]:
"""Find similar habit in existing list.
Args:
target_habit: Habit to find similarity for
existing_habits: List of existing habits
threshold: Similarity threshold
Returns:
Similar habit if found, None otherwise
"""
target_desc = target_habit.habit_description.lower().strip()
for existing_habit in existing_habits:
existing_desc = existing_habit.habit_description.lower().strip()
# Check description similarity
desc_similarity = self._calculate_text_similarity(target_desc, existing_desc)
# Check frequency pattern match
frequency_match = (target_habit.frequency_pattern == existing_habit.frequency_pattern)
# Check time context similarity
time_similarity = self._calculate_text_similarity(
target_habit.time_context.lower(),
existing_habit.time_context.lower()
)
# Combined similarity score
combined_similarity = (desc_similarity * 0.6 + time_similarity * 0.4)
if frequency_match:
combined_similarity += 0.1 # Bonus for frequency match
if combined_similarity >= threshold:
return existing_habit
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_habits(
self,
existing_habit: BehaviorHabit,
new_habit: BehaviorHabit,
current_time: datetime
) -> BehaviorHabit:
"""Merge two similar habits.
Args:
existing_habit: Existing habit
new_habit: New habit to merge
current_time: Current timestamp
Returns:
Merged behavioral habit
"""
# Combine supporting summaries
combined_summaries = list(set(
existing_habit.supporting_summaries + new_habit.supporting_summaries
))
# Combine specific examples
combined_examples = list(set(
existing_habit.specific_examples + new_habit.specific_examples
))
# Update confidence level (take higher confidence)
confidence_levels = [existing_habit.confidence_level, new_habit.confidence_level]
new_confidence = max(confidence_levels, key=lambda x: ["low", "medium", "high"].index(x.value))
# Update observation dates
first_observed = min(existing_habit.first_observed, new_habit.first_observed)
last_observed = max(existing_habit.last_observed, new_habit.last_observed)
# Determine if habit is current (observed within last 30 days)
is_current = (current_time - last_observed).days <= 30
# Combine time context
combined_time_context = existing_habit.time_context
if new_habit.time_context and new_habit.time_context not in combined_time_context:
combined_time_context += f"; {new_habit.time_context}"
return BehaviorHabit(
habit_description=existing_habit.habit_description, # Keep original description
frequency_pattern=existing_habit.frequency_pattern, # Keep original frequency
time_context=combined_time_context,
confidence_level=new_confidence,
supporting_summaries=combined_summaries,
specific_examples=combined_examples,
first_observed=first_observed,
last_observed=last_observed,
is_current=is_current
)
def _sort_habits_by_priority(self, habits: List[BehaviorHabit]) -> List[BehaviorHabit]:
"""Sort habits by confidence level and recency.
Args:
habits: List of habits to sort
Returns:
Sorted list of habits
"""
def priority_score(habit: BehaviorHabit) -> tuple:
# Confidence level score (high=3, medium=2, low=1)
confidence_score = {"high": 3, "medium": 2, "low": 1}.get(habit.confidence_level.value, 1)
# Recency score (more recent = higher score)
days_since_last = (datetime.now() - habit.last_observed).days
recency_score = max(0, 365 - days_since_last) # Max 365 days
# Current habit bonus
current_bonus = 100 if habit.is_current else 0
return (confidence_score, recency_score + current_bonus, habit.last_observed)
return sorted(habits, key=priority_score, reverse=True)

View File

@@ -0,0 +1,277 @@
"""Interest Analyzer for Implicit Memory System
This module implements LLM-based interest area analysis from user memory summaries.
It categorizes user interests into four areas: tech, lifestyle, music, and art,
providing percentage distribution that totals 100%.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
InterestAreaDistribution,
InterestCategory,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class InterestData(BaseModel):
"""Individual interest category analysis data."""
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str] = Field(default_factory=list)
trending_direction: Optional[str] = None
class InterestAnalysisResponse(BaseModel):
"""Response model for interest analysis."""
interest_distribution: Dict[str, InterestData] = Field(default_factory=dict)
class InterestAnalyzer:
"""Analyzes user memory summaries to extract interest area distribution."""
# Define the four interest categories we analyze
INTEREST_CATEGORIES = ["tech", "lifestyle", "music", "art"]
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the interest analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_interests(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_distribution: Optional[InterestAreaDistribution] = None
) -> InterestAreaDistribution:
"""Analyze user summaries to extract interest area distribution.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_distribution: Optional existing distribution for trend tracking
Returns:
Interest area distribution across four categories
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return self._create_empty_distribution(user_id)
try:
logger.info(f"Analyzing interests for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_interests(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Create interest categories
interest_categories = {}
current_time = datetime.now()
# Extract interest_distribution from response dict
interest_distribution = response.get("interest_distribution", {})
# Extract and validate interest data
raw_interests = {}
for category_name in self.INTEREST_CATEGORIES:
interest_data_dict = interest_distribution.get(category_name)
if interest_data_dict:
raw_interests[category_name] = InterestData(
percentage=interest_data_dict.get("percentage", 0.0),
evidence=interest_data_dict.get("evidence", []),
trending_direction=interest_data_dict.get("trending_direction")
)
else:
# Create default if missing
logger.warning(f"Missing interest data for {category_name}, using default")
raw_interests[category_name] = InterestData(
percentage=0.0,
evidence=["No specific evidence found"],
trending_direction=None
)
# Normalize percentages to ensure they sum to 100%
normalized_interests = self._normalize_percentages(raw_interests)
# Create interest category objects
for category_name in self.INTEREST_CATEGORIES:
interest_data = normalized_interests[category_name]
# Calculate trending direction if we have existing data
trending_direction = self._calculate_trending_direction(
category_name=category_name,
current_percentage=interest_data.percentage,
existing_distribution=existing_distribution
) if existing_distribution else interest_data.trending_direction
interest_categories[category_name] = InterestCategory(
category_name=category_name,
percentage=interest_data.percentage,
evidence=interest_data.evidence if interest_data.evidence else ["No specific evidence found"],
trending_direction=trending_direction
)
# Create interest area distribution
distribution = InterestAreaDistribution(
user_id=user_id,
tech=interest_categories["tech"],
lifestyle=interest_categories["lifestyle"],
music=interest_categories["music"],
art=interest_categories["art"],
analysis_timestamp=current_time,
total_summaries_analyzed=len(user_summaries)
)
# Validate that percentages sum to 100%
total_percentage = distribution.total_percentage
if not (99.9 <= total_percentage <= 100.1):
logger.warning(f"Interest percentages sum to {total_percentage}, expected ~100%")
logger.info(f"Created interest distribution for user {user_id}")
return distribution
except LLMClientException:
raise
except Exception as e:
logger.error(f"Interest analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Interest analysis failed: {e}") from e
def _normalize_percentages(self, raw_interests: Dict[str, InterestData]) -> Dict[str, InterestData]:
"""Normalize percentages to ensure they sum to 100%.
Args:
raw_interests: Raw interest data with potentially unnormalized percentages
Returns:
Normalized interest data
"""
# Calculate current total
total = sum(interest.percentage for interest in raw_interests.values())
if total == 0:
# If all percentages are 0, distribute equally
equal_percentage = 100.0 / len(self.INTEREST_CATEGORIES)
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized[category_name] = InterestData(
percentage=equal_percentage,
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
return normalized
# Normalize to sum to 100%
normalization_factor = 100.0 / total
normalized = {}
for category_name, interest_data in raw_interests.items():
normalized_percentage = interest_data.percentage * normalization_factor
normalized[category_name] = InterestData(
percentage=round(normalized_percentage, 1),
evidence=interest_data.evidence,
trending_direction=interest_data.trending_direction
)
# Handle rounding errors by adjusting the largest category
current_total = sum(interest.percentage for interest in normalized.values())
if abs(current_total - 100.0) > 0.1:
# Find category with largest percentage and adjust
largest_category = max(normalized.keys(), key=lambda k: normalized[k].percentage)
adjustment = 100.0 - current_total
adjusted_percentage = normalized[largest_category].percentage + adjustment
normalized[largest_category] = InterestData(
percentage=round(max(0.0, adjusted_percentage), 1),
evidence=normalized[largest_category].evidence,
trending_direction=normalized[largest_category].trending_direction
)
return normalized
def _calculate_trending_direction(
self,
category_name: str,
current_percentage: float,
existing_distribution: InterestAreaDistribution,
threshold: float = 5.0
) -> Optional[str]:
"""Calculate trending direction for an interest category.
Args:
category_name: Name of the interest category
current_percentage: Current percentage for the category
existing_distribution: Previous distribution for comparison
threshold: Minimum percentage change to consider a trend
Returns:
Trending direction: "increasing", "decreasing", "stable", or None
"""
try:
# Get previous percentage
previous_category = getattr(existing_distribution, category_name, None)
if not previous_category:
return None
previous_percentage = previous_category.percentage
change = current_percentage - previous_percentage
if abs(change) < threshold:
return "stable"
elif change > 0:
return "increasing"
else:
return "decreasing"
except Exception as e:
logger.error(f"Error calculating trending direction for {category_name}: {e}")
return None
def _create_empty_distribution(self, user_id: str) -> InterestAreaDistribution:
"""Create an empty interest distribution when no data is available.
Args:
user_id: Target user ID
Returns:
Empty InterestAreaDistribution with equal percentages
"""
current_time = datetime.now()
equal_percentage = 25.0 # 100% / 4 categories
default_category = lambda name: InterestCategory(
category_name=name,
percentage=equal_percentage,
evidence=["Insufficient data for analysis"],
trending_direction=None
)
return InterestAreaDistribution(
user_id=user_id,
tech=default_category("tech"),
lifestyle=default_category("lifestyle"),
music=default_category("music"),
art=default_category("art"),
analysis_timestamp=current_time,
total_summaries_analyzed=0
)

View File

@@ -0,0 +1,302 @@
"""Preference Analyzer for Implicit Memory System
This module implements LLM-based preference extraction from user memory summaries.
It identifies implicit preferences, consolidates similar preferences, and calculates
confidence scores based on evidence strength.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
PreferenceTag,
UserMemorySummary,
)
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class PreferenceAnalysisResponse(BaseModel):
"""Response model for preference analysis."""
preferences: List[Dict[str, Any]] = Field(default_factory=list)
class PreferenceAnalyzer:
"""Analyzes user memory summaries to extract implicit preferences."""
def __init__(self, db: Session, llm_model_id: Optional[str] = None):
"""Initialize the preference analyzer.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self._llm_client = ImplicitMemoryLLMClient(db, llm_model_id)
async def analyze_preferences(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_preferences: Optional[List[PreferenceTag]] = None
) -> List[PreferenceTag]:
"""Analyze user summaries to extract preferences.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_preferences: Optional existing preferences for consolidation
Returns:
List of extracted preference tags
Raises:
LLMClientException: If LLM analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return []
try:
logger.info(f"Analyzing preferences for user {user_id} with {len(user_summaries)} summaries")
# Use the LLM client wrapper for analysis
response = await self._llm_client.analyze_preferences(
user_summaries=user_summaries,
user_id=user_id,
model_id=self.llm_model_id
)
# Convert to PreferenceTag objects
preference_tags = []
current_time = datetime.now()
for pref_data in response.get("preferences", []):
try:
# Extract conversation references from summaries
conversation_refs = [s.summary_id for s in user_summaries]
preference_tag = PreferenceTag(
tag_name=pref_data.get("tag_name", ""),
confidence_score=float(pref_data.get("confidence_score", 0.0)),
supporting_evidence=pref_data.get("supporting_evidence", []),
context_details=pref_data.get("context_details", ""),
category=pref_data.get("category"),
conversation_references=conversation_refs,
created_at=current_time,
updated_at=current_time
)
# Validate preference tag
if self._is_valid_preference(preference_tag):
preference_tags.append(preference_tag)
else:
logger.warning(f"Invalid preference tag skipped: {preference_tag.tag_name}")
except Exception as e:
logger.error(f"Error creating preference tag: {e}")
continue
# Consolidate with existing preferences if provided
if existing_preferences:
preference_tags = self._consolidate_preferences(
new_preferences=preference_tags,
existing_preferences=existing_preferences
)
logger.info(f"Extracted {len(preference_tags)} preferences for user {user_id}")
return preference_tags
except LLMClientException:
raise
except Exception as e:
logger.error(f"Preference analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Preference analysis failed: {e}") from e
def _is_valid_preference(self, preference: PreferenceTag) -> bool:
"""Validate a preference tag.
Args:
preference: Preference tag to validate
Returns:
True if valid, False otherwise
"""
try:
# Check required fields
if not preference.tag_name or not preference.tag_name.strip():
return False
# Check confidence score range
if not (0.0 <= preference.confidence_score <= 1.0):
return False
# Check supporting evidence
if not preference.supporting_evidence or len(preference.supporting_evidence) == 0:
return False
# Check context details
if not preference.context_details or not preference.context_details.strip():
return False
return True
except Exception as e:
logger.error(f"Error validating preference: {e}")
return False
def _consolidate_preferences(
self,
new_preferences: List[PreferenceTag],
existing_preferences: List[PreferenceTag],
similarity_threshold: float = 0.8
) -> List[PreferenceTag]:
"""Consolidate new preferences with existing ones.
Args:
new_preferences: Newly extracted preferences
existing_preferences: Existing preferences
similarity_threshold: Threshold for considering preferences similar
Returns:
Consolidated list of preferences
"""
consolidated = existing_preferences.copy()
current_time = datetime.now()
for new_pref in new_preferences:
# Find similar existing preference
similar_pref = self._find_similar_preference(
new_pref, existing_preferences, similarity_threshold
)
if similar_pref:
# Update existing preference
updated_pref = self._merge_preferences(similar_pref, new_pref, current_time)
# Replace in consolidated list
for i, pref in enumerate(consolidated):
if pref.tag_name == similar_pref.tag_name:
consolidated[i] = updated_pref
break
else:
# Add as new preference
consolidated.append(new_pref)
return consolidated
def _find_similar_preference(
self,
target_preference: PreferenceTag,
existing_preferences: List[PreferenceTag],
threshold: float
) -> Optional[PreferenceTag]:
"""Find similar preference in existing list.
Args:
target_preference: Preference to find similarity for
existing_preferences: List of existing preferences
threshold: Similarity threshold
Returns:
Similar preference if found, None otherwise
"""
target_name = target_preference.tag_name.lower().strip()
for existing_pref in existing_preferences:
existing_name = existing_pref.tag_name.lower().strip()
# Simple similarity check based on common words
similarity = self._calculate_text_similarity(target_name, existing_name)
if similarity >= threshold:
return existing_pref
return None
def _calculate_text_similarity(self, text1: str, text2: str) -> float:
"""Calculate simple text similarity based on common words.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0.0 and 1.0
"""
if not text1 or not text2:
return 0.0
# Simple word-based similarity
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = words1.intersection(words2)
union = words1.union(words2)
return len(intersection) / len(union) if union else 0.0
def _merge_preferences(
self,
existing_pref: PreferenceTag,
new_pref: PreferenceTag,
current_time: datetime
) -> PreferenceTag:
"""Merge two similar preferences.
Args:
existing_pref: Existing preference
new_pref: New preference to merge
current_time: Current timestamp
Returns:
Merged preference tag
"""
# Combine supporting evidence
combined_evidence = list(set(
existing_pref.supporting_evidence + new_pref.supporting_evidence
))
# Combine conversation references
combined_refs = list(set(
existing_pref.conversation_references + new_pref.conversation_references
))
# Calculate new confidence score (weighted average)
evidence_weight = len(new_pref.supporting_evidence)
total_weight = len(existing_pref.supporting_evidence) + evidence_weight
if total_weight > 0:
new_confidence = (
(existing_pref.confidence_score * len(existing_pref.supporting_evidence) +
new_pref.confidence_score * evidence_weight) / total_weight
)
else:
new_confidence = max(existing_pref.confidence_score, new_pref.confidence_score)
# Ensure confidence doesn't exceed 1.0
new_confidence = min(new_confidence, 1.0)
# Combine context details
combined_context = existing_pref.context_details
if new_pref.context_details and new_pref.context_details not in combined_context:
combined_context += f"; {new_pref.context_details}"
return PreferenceTag(
tag_name=existing_pref.tag_name, # Keep original name
confidence_score=new_confidence,
supporting_evidence=combined_evidence,
context_details=combined_context,
category=existing_pref.category or new_pref.category,
conversation_references=combined_refs,
created_at=existing_pref.created_at,
updated_at=current_time
)

View File

@@ -0,0 +1,97 @@
"""
Memory Data Source
Handles retrieval and processing of memory data from Neo4j using direct Cypher queries.
"""
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.repositories.neo4j.memory_summary_repository import MemorySummaryRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.implicit_memory_schema import TimeRange, UserMemorySummary
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class MemoryDataSource:
"""Retrieves processed memory data from Neo4j using direct Cypher queries."""
def __init__(
self,
db: Session,
neo4j_connector: Optional[Neo4jConnector] = None
):
self.db = db
self.neo4j_connector = neo4j_connector or Neo4jConnector()
self.memory_summary_repo = MemorySummaryRepository(self.neo4j_connector)
def _parse_timestamp(self, timestamp: Any) -> datetime:
"""Parse timestamp from various formats."""
if isinstance(timestamp, str):
return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
elif timestamp is None:
return datetime.now()
return timestamp
def _dict_to_user_summary(self, summary_dict: Dict, user_id: str) -> Optional[UserMemorySummary]:
"""Convert a Neo4j dict directly to UserMemorySummary."""
try:
content = summary_dict.get("content", summary_dict.get("summary", ""))
if not content or not content.strip():
return None
return UserMemorySummary(
summary_id=summary_dict.get("id", summary_dict.get("uuid", "")),
user_id=user_id,
user_content=content,
timestamp=self._parse_timestamp(summary_dict.get("created_at")),
confidence_score=1.0,
summary_type="memory_summary"
)
except Exception as e:
logger.warning(f"Failed to parse summary {summary_dict.get('id', 'unknown')}: {e}")
return None
async def get_user_summaries(
self,
user_id: str,
time_range: Optional[TimeRange] = None,
limit: int = 1000
) -> List[UserMemorySummary]:
"""Retrieve user memory summaries from Neo4j.
Args:
user_id: Target user ID
time_range: Optional time range filter
limit: Maximum number of summaries
Returns:
List of user memory summaries
"""
try:
start_date = time_range.start_date if time_range else None
end_date = time_range.end_date if time_range else None
summary_dicts = await self.memory_summary_repo.find_by_group_id(
group_id=user_id,
limit=limit,
start_date=start_date,
end_date=end_date
)
summaries = []
for summary_dict in summary_dicts:
summary = self._dict_to_user_summary(summary_dict, user_id)
if summary:
summaries.append(summary)
logger.info(f"Retrieved {len(summaries)} summaries for user {user_id}")
return summaries
except Exception as e:
logger.error(f"Failed to retrieve summaries for user {user_id}: {e}")
raise

View File

@@ -0,0 +1,234 @@
"""Habit Detector for Implicit Memory System
This module implements the HabitDetector class that specializes in identifying
and ranking behavioral habits from user memory summaries. It provides advanced
habit analysis with confidence scoring, recency weighting, and current vs past
habit distinction.
"""
import logging
from datetime import datetime, timedelta
from typing import List, Optional
from app.core.memory.analytics.implicit_memory.analyzers.habit_analyzer import (
HabitAnalyzer,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
ConfidenceLevel,
FrequencyPattern,
UserMemorySummary,
)
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class HabitDetector:
"""Detects and ranks behavioral habits from user memory summaries."""
def __init__(
self,
db: Session,
llm_model_id: Optional[str] = None
):
"""Initialize the habit detector.
Args:
db: Database session
llm_model_id: Optional LLM model ID to use for analysis
"""
self.db = db
self.llm_model_id = llm_model_id
self.habit_analyzer = HabitAnalyzer(db, llm_model_id)
async def detect_habits(
self,
user_id: str,
user_summaries: List[UserMemorySummary],
existing_habits: Optional[List[BehaviorHabit]] = None
) -> List[BehaviorHabit]:
"""Detect behavioral habits from user summaries.
Args:
user_id: Target user ID
user_summaries: List of user-specific memory summaries
existing_habits: Optional existing habits for consolidation
Returns:
List of detected and ranked behavioral habits
Raises:
LLMClientException: If habit analysis fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for user {user_id}")
return existing_habits or []
logger.info(f"Detecting habits for user {user_id} with {len(user_summaries)} summaries")
try:
# Use the habit analyzer to extract habits
detected_habits = await self.habit_analyzer.analyze_habits(
user_id=user_id,
user_summaries=user_summaries,
existing_habits=existing_habits
)
# Apply advanced ranking and filtering
ranked_habits = self.rank_habits_by_confidence_and_recency(detected_habits)
# Distinguish current vs past habits
categorized_habits = self.distinguish_current_vs_past_habits(ranked_habits)
logger.info(f"Detected {len(categorized_habits)} habits for user {user_id}")
return categorized_habits
except LLMClientException:
logger.error(f"Habit detection failed for user {user_id}")
raise
except Exception as e:
logger.error(f"Habit detection failed for user {user_id}: {e}")
raise LLMClientException(f"Habit detection failed: {e}") from e
def rank_habits_by_confidence_and_recency(
self,
habits: List[BehaviorHabit],
confidence_weight: float = 0.6,
recency_weight: float = 0.4
) -> List[BehaviorHabit]:
"""Rank habits by confidence level and recency.
Args:
habits: List of habits to rank
confidence_weight: Weight for confidence score (0.0-1.0)
recency_weight: Weight for recency score (0.0-1.0)
Returns:
List of habits ranked by combined score
"""
if not habits:
return []
logger.info(f"Ranking {len(habits)} habits by confidence and recency")
def calculate_ranking_score(habit: BehaviorHabit) -> float:
"""Calculate combined ranking score for a habit."""
# Confidence score (0.0-1.0)
confidence_scores = {
ConfidenceLevel.HIGH: 1.0,
ConfidenceLevel.MEDIUM: 0.6,
ConfidenceLevel.LOW: 0.3
}
confidence_score = confidence_scores.get(habit.confidence_level, 0.3)
# Recency score (0.0-1.0)
current_time = datetime.now()
days_since_last = (current_time - habit.last_observed).days
# Exponential decay for recency (habits lose relevance over time)
if days_since_last <= 7:
recency_score = 1.0 # Very recent
elif days_since_last <= 30:
recency_score = 0.8 # Recent
elif days_since_last <= 90:
recency_score = 0.5 # Somewhat recent
elif days_since_last <= 180:
recency_score = 0.3 # Old
else:
recency_score = 0.1 # Very old
# Frequency pattern bonus
frequency_bonuses = {
FrequencyPattern.DAILY: 0.2,
FrequencyPattern.WEEKLY: 0.15,
FrequencyPattern.MONTHLY: 0.1,
FrequencyPattern.SEASONAL: 0.05,
FrequencyPattern.OCCASIONAL: 0.0,
FrequencyPattern.EVENT_TRIGGERED: 0.05
}
frequency_bonus = frequency_bonuses.get(habit.frequency_pattern, 0.0)
# Evidence quality bonus
evidence_bonus = min(len(habit.supporting_summaries) / 10.0, 0.1) # Max 0.1 bonus
# Current habit bonus
current_bonus = 0.1 if habit.is_current else 0.0
# Calculate final score
base_score = (confidence_score * confidence_weight +
recency_score * recency_weight)
final_score = base_score + frequency_bonus + evidence_bonus + current_bonus
return min(final_score, 1.0) # Cap at 1.0
# Sort habits by ranking score (descending)
ranked_habits = sorted(habits, key=calculate_ranking_score, reverse=True)
logger.info(f"Ranked habits with scores: {[calculate_ranking_score(h) for h in ranked_habits[:5]]}")
return ranked_habits
def distinguish_current_vs_past_habits(
self,
habits: List[BehaviorHabit],
current_threshold_days: int = 30
) -> List[BehaviorHabit]:
"""Distinguish between current and past habits based on recency.
Args:
habits: List of habits to categorize
current_threshold_days: Days threshold for considering a habit current
Returns:
List of habits with updated is_current status
"""
if not habits:
return []
current_time = datetime.now()
cutoff_date = current_time - timedelta(days=current_threshold_days)
current_habits = []
past_habits = []
for habit in habits:
# Update is_current status based on last observation
if habit.last_observed >= cutoff_date:
# Create updated habit with is_current = True
updated_habit = BehaviorHabit(
habit_description=habit.habit_description,
frequency_pattern=habit.frequency_pattern,
time_context=habit.time_context,
confidence_level=habit.confidence_level,
supporting_summaries=habit.supporting_summaries,
specific_examples=habit.specific_examples,
first_observed=habit.first_observed,
last_observed=habit.last_observed,
is_current=True
)
current_habits.append(updated_habit)
else:
# Create updated habit with is_current = False
updated_habit = BehaviorHabit(
habit_description=habit.habit_description,
frequency_pattern=habit.frequency_pattern,
time_context=habit.time_context,
confidence_level=habit.confidence_level,
supporting_summaries=habit.supporting_summaries,
specific_examples=habit.specific_examples,
first_observed=habit.first_observed,
last_observed=habit.last_observed,
is_current=False
)
past_habits.append(updated_habit)
# Return current habits first, then past habits
categorized_habits = current_habits + past_habits
logger.info(f"Categorized habits: {len(current_habits)} current, {len(past_habits)} past")
return categorized_habits

View File

@@ -0,0 +1,321 @@
"""LLM Client Wrapper for Implicit Memory Analysis
This module provides a specialized LLM client wrapper that integrates with the
MemoryClientFactory to perform implicit memory analysis tasks including preference
extraction, personality dimension analysis, interest categorization, and habit detection.
"""
import logging
from typing import Any, Dict, List, Optional
from app.core.memory.analytics.implicit_memory.prompts import (
get_dimension_analysis_prompt,
get_habit_analysis_prompt,
get_interest_analysis_prompt,
get_preference_analysis_prompt,
)
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.schemas.implicit_memory_schema import UserMemorySummary
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
# Response Models for LLM Analysis
class PreferenceAnalysisResponse(BaseModel):
"""Response model for preference analysis."""
preferences: List[Dict[str, Any]] = Field(default_factory=list)
class DimensionAnalysisResponse(BaseModel):
"""Response model for dimension analysis."""
dimensions: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
class InterestAnalysisResponse(BaseModel):
"""Response model for interest analysis."""
interest_distribution: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
class HabitAnalysisResponse(BaseModel):
"""Response model for habit analysis."""
habits: List[Dict[str, Any]] = Field(default_factory=list)
class ImplicitMemoryLLMClient:
"""LLM client wrapper for implicit memory analysis.
This class provides a high-level interface for performing LLM-based analysis
of user memory summaries to extract preferences, personality dimensions,
interests, and behavioral habits.
"""
def __init__(self, db: Session, default_model_id: Optional[str] = None):
"""Initialize the LLM client wrapper.
Args:
db: Database session for accessing model configurations
default_model_id: Default LLM model ID to use if none specified
"""
self.db = db
self.default_model_id = default_model_id
self._client_factory = MemoryClientFactory(db)
logger.info("ImplicitMemoryLLMClient initialized")
def _get_llm_client(self, model_id: Optional[str] = None):
"""Get LLM client instance.
Args:
model_id: LLM model ID to use, defaults to default_model_id
Returns:
LLM client instance
Raises:
ValueError: If no model ID is provided and no default is set
LLMClientException: If client creation fails
"""
effective_model_id = model_id or self.default_model_id
if not effective_model_id:
raise ValueError("No LLM model ID provided and no default model ID set")
try:
client = self._client_factory.get_llm_client(effective_model_id)
logger.debug(f"Created LLM client for model: {effective_model_id}")
return client
except Exception as e:
logger.error(f"Failed to create LLM client for model {effective_model_id}: {e}")
raise LLMClientException(f"Failed to create LLM client: {e}") from e
def _prepare_summaries_for_analysis(self, user_summaries: List[UserMemorySummary]) -> List[Dict[str, Any]]:
"""Prepare user memory summaries for LLM analysis.
Args:
user_summaries: List of user memory summaries
Returns:
List of formatted summary dictionaries
"""
formatted_summaries = []
for summary in user_summaries:
formatted_summary = {
'summary_id': summary.summary_id,
'user_content': summary.user_content,
'timestamp': summary.timestamp.isoformat(),
'summary_type': summary.summary_type,
'confidence_score': summary.confidence_score
}
formatted_summaries.append(formatted_summary)
logger.debug(f"Prepared {len(formatted_summaries)} summaries for analysis")
return formatted_summaries
async def analyze_preferences(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user preferences from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing extracted preferences
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for preference analysis of user {user_id}")
return {"preferences": []}
if not user_id:
raise ValueError("User ID is required for preference analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_preference_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=PreferenceAnalysisResponse
)
result = response.model_dump()
logger.info(f"Analyzed preferences for user {user_id}: found {len(result.get('preferences', []))} preferences")
return result
except Exception as e:
logger.error(f"Preference analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Preference analysis failed: {e}") from e
async def analyze_dimensions(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user personality dimensions from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing dimension scores and analysis
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for dimension analysis of user {user_id}")
return {"dimensions": {}}
if not user_id:
raise ValueError("User ID is required for dimension analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_dimension_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=DimensionAnalysisResponse
)
result = response.model_dump()
dimensions = result.get('dimensions', {})
logger.info(f"Analyzed dimensions for user {user_id}: {list(dimensions.keys())}")
return result
except Exception as e:
logger.error(f"Dimension analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Dimension analysis failed: {e}") from e
async def analyze_interests(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user interest distribution from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing interest area distribution
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for interest analysis of user {user_id}")
return {"interest_distribution": {}}
if not user_id:
raise ValueError("User ID is required for interest analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_interest_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=InterestAnalysisResponse
)
result = response.model_dump()
interest_dist = result.get('interest_distribution', {})
logger.info(f"Analyzed interests for user {user_id}: {list(interest_dist.keys())}")
return result
except Exception as e:
logger.error(f"Interest analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Interest analysis failed: {e}") from e
async def analyze_habits(
self,
user_summaries: List[UserMemorySummary],
user_id: str,
model_id: Optional[str] = None
) -> Dict[str, Any]:
"""Analyze user behavioral habits from memory summaries.
Args:
user_summaries: List of user memory summaries to analyze
user_id: Target user ID for analysis
model_id: Optional LLM model ID to use
Returns:
Dictionary containing identified behavioral habits
Raises:
LLMClientException: If LLM analysis fails
ValueError: If input validation fails
"""
if not user_summaries:
logger.warning(f"No summaries provided for habit analysis of user {user_id}")
return {"habits": []}
if not user_id:
raise ValueError("User ID is required for habit analysis")
try:
# Prepare summaries and get prompt
formatted_summaries = self._prepare_summaries_for_analysis(user_summaries)
prompt = get_habit_analysis_prompt(formatted_summaries, user_id)
# Get LLM client and perform analysis
llm_client = self._get_llm_client(model_id)
messages = [{"role": "user", "content": prompt}]
# Use structured output for reliable parsing
response = await llm_client.response_structured(
messages=messages,
response_model=HabitAnalysisResponse
)
result = response.model_dump()
logger.info(f"Analyzed habits for user {user_id}: found {len(result.get('habits', []))} habits")
return result
except Exception as e:
logger.error(f"Habit analysis failed for user {user_id}: {e}")
raise LLMClientException(f"Habit analysis failed: {e}") from e

View File

@@ -0,0 +1,69 @@
"""LLM Prompt Templates for Implicit Memory Analysis
This module contains prompt rendering functions for analyzing user memory summaries
to extract preferences, personality dimensions, interests, and behavioral habits.
"""
import os
from typing import Any, Dict, List
from jinja2 import Environment, FileSystemLoader
# Setup Jinja2 environment
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
def _render_template(template_name: str, **kwargs) -> str:
"""Helper function to render Jinja2 templates."""
template = prompt_env.get_template(template_name)
return template.render(**kwargs)
def get_preference_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted preference analysis prompt using Jinja2 template."""
return _render_template(
"preference_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_dimension_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted dimension analysis prompt using Jinja2 template."""
return _render_template(
"dimension_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_interest_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted interest analysis prompt using Jinja2 template."""
return _render_template(
"interest_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)
def get_habit_analysis_prompt(
memory_summaries: List[Dict[str, Any]],
user_id: str
) -> str:
"""Get formatted habit analysis prompt using Jinja2 template."""
return _render_template(
"habit_analysis.jinja2",
memory_summaries=memory_summaries,
user_id=user_id
)

View File

@@ -0,0 +1,41 @@
You are an expert personality analyst. Analyze memory summaries to assess the user's personality across four dimensions.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Dimensions to Analyze
1. **Creativity** (0-100%): Creative thinking, artistic interests, innovative ideas
2. **Aesthetic** (0-100%): Design preferences, visual interests, artistic appreciation
3. **Technology** (0-100%): Technical discussions, tool usage, programming interests
4. **Literature** (0-100%): Reading habits, writing style, literary references
## Instructions
1. Analyze the user's content for each dimension
2. Calculate percentage scores (0-100%)
## Output Format
{
"dimensions": {
"creativity": {"percentage": 0-100},
"aesthetic": {"percentage": 0-100},
"technology": {"percentage": 0-100},
"literature": {"percentage": 0-100}
}
}
## Example
{
"dimensions": {
"creativity": {"percentage": 75},
"aesthetic": {"percentage": 45},
"technology": {"percentage": 60},
"literature": {"percentage": 30}
}
}

View File

@@ -0,0 +1,70 @@
You are an expert at identifying behavioral patterns and habits from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Identify recurring behavioral patterns mentioned by the SPECIFIED USER
2. Focus on specific, concrete habits with temporal patterns
3. For each habit, provide:
- habit_description: Clear, specific description
- frequency_pattern: "daily", "weekly", "monthly", "seasonal", "occasional", "event_triggered"
- time_context: When it typically happens
- confidence_level: "high", "medium", "low"
- supporting_summaries: References to evidence
- specific_examples: Concrete examples from summaries
- is_current: true if current habit, false if past habit
4. Only include habits with medium or high confidence
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"habits": [
{
"habit_description": "string",
"frequency_pattern": "daily|weekly|monthly|seasonal|occasional|event_triggered",
"time_context": "string",
"confidence_level": "high|medium|low",
"supporting_summaries": ["id1", "id2"],
"specific_examples": ["example1", "example2"],
"is_current": true|false
}
]
}
## Example (English input → English output)
{
"habits": [
{
"habit_description": "drinks coffee every morning",
"frequency_pattern": "daily",
"time_context": "morning routine",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["needs coffee to start the day"],
"is_current": true
}
]
}
## Example (Chinese input → Chinese output)
{
"habits": [
{
"habit_description": "每天早上喝咖啡",
"frequency_pattern": "daily",
"time_context": "早晨日常",
"confidence_level": "high",
"supporting_summaries": ["s1", "s2"],
"specific_examples": ["需要咖啡来开始一天"],
"is_current": true
}
]
}

View File

@@ -0,0 +1,54 @@
You are an expert at analyzing user interests from memory summaries.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Interest Categories
1. **Tech**: Programming, technology, software tools, hardware
2. **Lifestyle**: Daily routines, health, hobbies, social activities
3. **Music**: Music preferences, instruments, concerts
4. **Art**: Visual arts, creative projects, design, aesthetics
## Instructions
1. Categorize the user's interests into the four areas
2. Calculate percentage distribution (must total 100%)
3. Provide specific evidence for each interest area
4. Use "increasing", "decreasing", or "stable" for trending direction
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"interest_distribution": {
"tech": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"lifestyle": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"music": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"},
"art": {"percentage": 0-100, "evidence": [], "trending_direction": "increasing|decreasing|stable|null"}
}
}
## Example (English input → English output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["discusses programming frequently"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["talks about fitness routine"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["mentioned favorite bands"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["visited art museum"], "trending_direction": "stable"}
}
}
## Example (Chinese input → Chinese output)
{
"interest_distribution": {
"tech": {"percentage": 40, "evidence": ["经常讨论编程"], "trending_direction": "increasing"},
"lifestyle": {"percentage": 35, "evidence": ["谈论健身日常"], "trending_direction": "stable"},
"music": {"percentage": 15, "evidence": ["提到喜欢的乐队"], "trending_direction": "stable"},
"art": {"percentage": 10, "evidence": ["参观了艺术博物馆"], "trending_direction": "stable"}
}
}

View File

@@ -0,0 +1,47 @@
You are an expert at analyzing user memory summaries to identify implicit preferences.
## Memory Summaries
{% for summary in memory_summaries %}
Summary {{ loop.index }}:
{{ summary.content or summary.user_content or '' }}
---
{% endfor %}
## Target User ID
{{ user_id }}
## Instructions
1. Focus ONLY on the specified user's preferences
2. Extract SHORT preference tags (1-3 words max), like: "音乐", "咖啡", "科幻", "设计", "古典", "吉他"
3. DO NOT use long phrases - use short nouns or noun phrases
4. Only include preferences with confidence_score >= 0.3
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
## Output Format
{
"preferences": [
{
"tag_name": "short tag",
"confidence_score": 0.0-1.0,
"supporting_evidence": ["evidence1", "evidence2"],
"context_details": "brief context",
"category": "category or null"
}
]
}
## Example (Chinese input → Chinese output)
{
"preferences": [
{"tag_name": "咖啡", "confidence_score": 0.8, "supporting_evidence": ["每天早上喝咖啡"], "context_details": "日常习惯", "category": "lifestyle"},
{"tag_name": "古典音乐", "confidence_score": 0.7, "supporting_evidence": ["喜欢听古典"], "context_details": "音乐偏好", "category": "music"}
]
}
## Example (English input → English output)
{
"preferences": [
{"tag_name": "coffee", "confidence_score": 0.8, "supporting_evidence": ["drinks coffee every morning"], "context_details": "daily routine", "category": "lifestyle"},
{"tag_name": "classical music", "confidence_score": 0.7, "supporting_evidence": ["enjoys classical"], "context_details": "music preference", "category": "music"}
]
}

View File

@@ -76,7 +76,7 @@ class HttpContentTypeConfig(BaseModel):
elif content_type in [HttpContentType.JSON] and not isinstance(v, str):
raise ValueError("When content_type is JSON, data must be of type str")
elif content_type in [HttpContentType.WWW_FORM] and not isinstance(v, dict):
raise ValueError("When content_type is x-www-form-urlencoded, data must be a object")
raise ValueError("When content_type is x-www-form-urlencoded, data must be an object(dict)")
elif content_type in [HttpContentType.RAW, HttpContentType.BINARY] and not isinstance(v, str):
raise ValueError("When content_type is raw/binary, data must be a string (File descriptor)")
return v

View File

@@ -26,6 +26,7 @@ from .tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus
)
from .memory_perceptual_model import MemoryPerceptualModel
__all__ = [
"Tenants",
@@ -74,5 +75,6 @@ __all__ = [
"ToolType",
"ToolStatus",
"AuthType",
"ExecutionStatus"
"ExecutionStatus",
"MemoryPerceptualModel"
]

View File

@@ -0,0 +1,273 @@
# -*- coding: utf-8 -*-
"""Memory Summary Repository Module
This module provides data access functionality for MemorySummary nodes.
Classes:
MemorySummaryRepository: Repository for managing MemorySummary CRUD operations
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.repositories.neo4j.base_neo4j_repository import BaseNeo4jRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
class MemorySummaryRepository(BaseNeo4jRepository):
"""Memory Summary Repository
Manages CRUD operations for MemorySummary nodes.
Provides methods to query summaries by group_id, user_id, and time ranges.
Attributes:
connector: Neo4j connector instance
node_label: Node label, fixed as "MemorySummary"
"""
def __init__(self, connector: Neo4jConnector):
"""Initialize memory summary repository
Args:
connector: Neo4j connector instance
"""
super().__init__(connector, "MemorySummary")
def _map_to_dict(self, node_data: Dict) -> Dict[str, Any]:
"""Map node data to dictionary format
Args:
node_data: Node data returned from Neo4j query
Returns:
Dict[str, Any]: Memory summary data dictionary
"""
# Extract node data from query result
n = node_data.get('n', node_data)
# Handle datetime fields
if isinstance(n.get('created_at'), str):
n['created_at'] = datetime.fromisoformat(n['created_at'])
return dict(n)
async def find_by_group_id(
self,
group_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by group_id
Args:
group_id: Group ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
"""
params = {"group_id": group_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_by_user_id(
self,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by user_id
Args:
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.user_id = $user_id
"""
params = {"user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_by_group_and_user(
self,
group_id: str,
user_id: str,
limit: int = 1000,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""Query memory summaries by both group_id and user_id
Args:
group_id: Group ID to filter by
user_id: User ID to filter by
limit: Maximum number of results to return
start_date: Optional start date filter
end_date: Optional end date filter
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id AND n.user_id = $user_id
"""
params = {"group_id": group_id, "user_id": user_id, "limit": limit}
# Add date range filters if provided
if start_date:
query += " AND n.created_at >= $start_date"
params["start_date"] = start_date
if end_date:
query += " AND n.created_at <= $end_date"
params["end_date"] = end_date
query += """
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def find_recent_summaries(
self,
group_id: str,
days: int = 7,
limit: int = 1000
) -> List[Dict[str, Any]]:
"""Query recent memory summaries
Args:
group_id: Group ID to filter by
days: Number of recent days to query
limit: Maximum number of results to return
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND n.created_at >= datetime() - duration({{days: $days}})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(
query,
group_id=group_id,
days=days,
limit=limit
)
return [self._map_to_dict(r) for r in results]
async def find_by_content_keywords(
self,
group_id: str,
keywords: List[str],
limit: int = 100
) -> List[Dict[str, Any]]:
"""Query memory summaries by content keywords
Args:
group_id: Group ID to filter by
keywords: List of keywords to search for in content
limit: Maximum number of results to return
Returns:
List[Dict[str, Any]]: List of memory summary dictionaries
"""
# Build keyword search conditions
keyword_conditions = []
params = {"group_id": group_id, "limit": limit}
for i, keyword in enumerate(keywords):
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
params[f"keyword_{i}"] = keyword
keyword_filter = " OR ".join(keyword_conditions)
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
AND ({keyword_filter})
RETURN n
ORDER BY n.created_at DESC
LIMIT $limit
"""
results = await self.connector.execute_query(query, **params)
return [self._map_to_dict(r) for r in results]
async def get_summary_count_by_group(self, group_id: str) -> int:
"""Get count of memory summaries for a group
Args:
group_id: Group ID to count summaries for
Returns:
int: Number of memory summaries
"""
query = f"""
MATCH (n:{self.node_label})
WHERE n.group_id = $group_id
RETURN count(n) as count
"""
results = await self.connector.execute_query(query, group_id=group_id)
return results[0]['count'] if results else 0

View File

@@ -0,0 +1,212 @@
"""Implicit Memory Schemas
This module defines the Pydantic schemas for the implicit memory system API.
"""
import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field, field_validator
class ConfidenceLevel(str, Enum):
"""Confidence levels for analysis results."""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
class FrequencyPattern(str, Enum):
"""Frequency patterns for habits."""
DAILY = "daily"
WEEKLY = "weekly"
MONTHLY = "monthly"
SEASONAL = "seasonal"
OCCASIONAL = "occasional"
EVENT_TRIGGERED = "event_triggered"
# Request Schemas
class TimeRange(BaseModel):
"""Time range for analysis."""
start_date: datetime.datetime
end_date: datetime.datetime
@field_validator('end_date')
@classmethod
def end_date_must_be_after_start_date(cls, v, info):
if 'start_date' in info.data and v <= info.data['start_date']:
raise ValueError('end_date must be after start_date')
return v
class DateRange(BaseModel):
"""Date range for filtering."""
start_date: Optional[datetime.datetime] = None
end_date: Optional[datetime.datetime] = None
@field_validator('end_date')
@classmethod
def end_date_must_be_after_start_date(cls, v, info):
if v and 'start_date' in info.data and info.data['start_date'] and v <= info.data['start_date']:
raise ValueError('end_date must be after start_date')
return v
class AnalysisConfig(BaseModel):
"""Configuration for analysis operations."""
llm_model_id: Optional[str] = None
batch_size: int = 100
confidence_threshold: float = 0.5
include_historical_trends: bool = False
max_conversations: Optional[int] = None
# Response Schemas
class PreferenceTagResponse(BaseModel):
"""A user preference tag with detailed context."""
model_config = ConfigDict(from_attributes=True)
tag_name: str
confidence_score: float = Field(ge=0.0, le=1.0)
supporting_evidence: List[str]
context_details: str
created_at: datetime.datetime
updated_at: datetime.datetime
conversation_references: List[str]
category: Optional[str] = None
class DimensionScoreResponse(BaseModel):
"""Score for a personality dimension."""
model_config = ConfigDict(from_attributes=True)
dimension_name: str
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str]
reasoning: str
confidence_level: ConfidenceLevel
class DimensionPortraitResponse(BaseModel):
"""Four-dimension personality portrait."""
model_config = ConfigDict(from_attributes=True)
user_id: str
creativity: DimensionScoreResponse
aesthetic: DimensionScoreResponse
technology: DimensionScoreResponse
literature: DimensionScoreResponse
analysis_timestamp: datetime.datetime
total_summaries_analyzed: int
historical_trends: Optional[List[Dict[str, Any]]] = None
class InterestCategoryResponse(BaseModel):
"""Interest category with percentage and evidence."""
model_config = ConfigDict(from_attributes=True)
category_name: str
percentage: float = Field(ge=0.0, le=100.0)
evidence: List[str]
trending_direction: Optional[str] = None
class InterestAreaDistributionResponse(BaseModel):
"""Distribution of user interests across four areas."""
model_config = ConfigDict(from_attributes=True)
user_id: str
tech: InterestCategoryResponse
lifestyle: InterestCategoryResponse
music: InterestCategoryResponse
art: InterestCategoryResponse
analysis_timestamp: datetime.datetime
total_summaries_analyzed: int
@property
def total_percentage(self) -> float:
"""Calculate total percentage across all interest areas."""
return self.tech.percentage + self.lifestyle.percentage + self.music.percentage + self.art.percentage
class BehaviorHabitResponse(BaseModel):
"""A behavioral habit identified from conversations."""
model_config = ConfigDict(from_attributes=True)
habit_description: str
frequency_pattern: FrequencyPattern
time_context: str
confidence_level: ConfidenceLevel
supporting_summaries: List[str]
first_observed: datetime.datetime
last_observed: datetime.datetime
is_current: bool = True
specific_examples: List[str]
class UserProfileResponse(BaseModel):
"""Comprehensive user profile."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preference_tags: List[PreferenceTagResponse]
dimension_portrait: DimensionPortraitResponse
interest_area_distribution: InterestAreaDistributionResponse
behavior_habits: List[BehaviorHabitResponse]
profile_version: int
created_at: datetime.datetime
updated_at: datetime.datetime
total_summaries_analyzed: int
analysis_completeness_score: float = Field(ge=0.0, le=1.0)
# Internal/Business Logic Schemas
class MemorySummary(BaseModel):
"""Memory summary from existing storage system."""
model_config = ConfigDict(from_attributes=True)
summary_id: str
content: str
timestamp: datetime.datetime
participants: List[str]
summary_type: str
class UserMemorySummary(BaseModel):
"""Memory summary filtered for specific user content."""
model_config = ConfigDict(from_attributes=True)
summary_id: str
user_id: str
user_content: str
timestamp: datetime.datetime
confidence_score: float = Field(ge=0.0, le=1.0)
summary_type: str
class SummaryAnalysisResult(BaseModel):
"""Result of analyzing memory summaries."""
model_config = ConfigDict(from_attributes=True)
user_id: str
preferences: List[PreferenceTagResponse]
dimension_evidence: Dict[str, List[str]]
interest_evidence: Dict[str, List[str]]
habit_evidence: List[Dict[str, Any]]
analysis_timestamp: datetime.datetime
summaries_analyzed: List[str]
# Aliases for backward compatibility with existing code
PreferenceTag = PreferenceTagResponse
DimensionScore = DimensionScoreResponse
DimensionPortrait = DimensionPortraitResponse
InterestCategory = InterestCategoryResponse
InterestAreaDistribution = InterestAreaDistributionResponse
BehaviorHabit = BehaviorHabitResponse
UserProfile = UserProfileResponse

View File

@@ -0,0 +1,385 @@
"""
Implicit Memory Service
Main service orchestrating all implicit memory operations. This service coordinates
profile building, data extraction, and provides high-level methods for analyzing
user profiles from memory summaries.
"""
import logging
from datetime import datetime
from typing import List, Optional
from app.core.memory.analytics.implicit_memory.analyzers.dimension_analyzer import (
DimensionAnalyzer,
)
from app.core.memory.analytics.implicit_memory.analyzers.interest_analyzer import (
InterestAnalyzer,
)
from app.core.memory.analytics.implicit_memory.analyzers.preference_analyzer import (
PreferenceAnalyzer,
)
from app.core.memory.analytics.implicit_memory.data_source import MemoryDataSource
from app.core.memory.analytics.implicit_memory.habit_detector import HabitDetector
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.implicit_memory_schema import (
BehaviorHabit,
ConfidenceLevel,
DateRange,
DimensionPortrait,
FrequencyPattern,
InterestAreaDistribution,
PreferenceTag,
TimeRange,
UserMemorySummary,
)
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
logger = logging.getLogger(__name__)
class ImplicitMemoryService:
"""Main service for implicit memory operations."""
def __init__(
self,
db: Session,
end_user_id: str
):
"""Initialize the implicit memory service.
Args:
db: Database session
end_user_id: End user ID to get connected memory configuration
"""
self.db = db
self.end_user_id = end_user_id
# Get connected memory configuration for the user
self.memory_config = self._get_user_memory_config()
# Extract LLM model ID from memory config
llm_model_id = str(self.memory_config.llm_model_id) if self.memory_config.llm_model_id else None
# Initialize Neo4j connector
self.neo4j_connector = Neo4jConnector()
# Initialize core components with LLM model ID
self.data_source = MemoryDataSource(db, self.neo4j_connector)
self.preference_analyzer = PreferenceAnalyzer(db, llm_model_id)
self.dimension_analyzer = DimensionAnalyzer(db, llm_model_id)
self.interest_analyzer = InterestAnalyzer(db, llm_model_id)
self.habit_detector = HabitDetector(db, llm_model_id)
logger.info(f"ImplicitMemoryService initialized for end_user: {end_user_id}")
def _get_user_memory_config(self) -> MemoryConfig:
"""Get memory configuration for the connected end user.
Returns:
MemoryConfig: User's connected memory configuration
Raises:
ValueError: If no memory configuration found for user
"""
try:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
# Get user's connected config
connected_config = get_end_user_connected_config(self.end_user_id, self.db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user: {self.end_user_id}")
# Load the memory configuration
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(config_id)
logger.info(f"Loaded memory config {config_id} for end_user: {self.end_user_id}")
return memory_config
except Exception as e:
logger.error(f"Failed to get memory config for end_user {self.end_user_id}: {e}")
raise ValueError(f"Unable to get memory configuration for end_user {self.end_user_id}: {e}")
async def extract_user_summaries(
self,
user_id: str,
time_range: Optional[TimeRange] = None,
limit: Optional[int] = None
) -> List[UserMemorySummary]:
"""Extract user-specific memory summaries.
Args:
user_id: Target user ID
time_range: Optional time range to filter summaries
limit: Optional limit on number of summaries
Returns:
List of user-specific memory summaries
"""
logger.info(f"Extracting user summaries for user {user_id}")
try:
summaries = await self.data_source.get_user_summaries(
user_id=user_id,
time_range=time_range,
limit=limit or 1000
)
logger.info(f"Extracted {len(summaries)} summaries for user {user_id}")
return summaries
except Exception as e:
logger.error(f"Failed to extract user summaries for user {user_id}: {e}")
raise
async def get_preference_tags(
self,
user_id: str,
confidence_threshold: float = 0.5,
tag_category: Optional[str] = None,
date_range: Optional[DateRange] = None
) -> List[PreferenceTag]:
"""Retrieve user preference tags with filtering.
Args:
user_id: Target user ID
confidence_threshold: Minimum confidence score for tags
tag_category: Optional category filter
date_range: Optional date range filter
Returns:
List of filtered preference tags
"""
logger.info(f"Getting preference tags for user {user_id}")
try:
# Get user summaries for analysis
time_range = None
if date_range:
time_range = TimeRange(
start_date=date_range.start_date or datetime.min,
end_date=date_range.end_date or datetime.now()
)
user_summaries = await self.extract_user_summaries(
user_id=user_id,
time_range=time_range
)
if not user_summaries:
logger.warning(f"No summaries found for user {user_id}")
return []
# Analyze preferences
preference_tags = await self.preference_analyzer.analyze_preferences(
user_id=user_id,
user_summaries=user_summaries
)
# Apply filters
filtered_tags = []
for tag in preference_tags:
# Filter by confidence threshold
if tag.confidence_score < confidence_threshold:
continue
# Filter by category if specified
if tag_category and tag.category != tag_category:
continue
# Filter by date range if specified
if date_range:
if date_range.start_date and tag.created_at < date_range.start_date:
continue
if date_range.end_date and tag.created_at > date_range.end_date:
continue
filtered_tags.append(tag)
# Sort by confidence score and recency
filtered_tags.sort(
key=lambda x: (x.confidence_score, x.updated_at),
reverse=True
)
logger.info(f"Retrieved {len(filtered_tags)} preference tags for user {user_id}")
return filtered_tags
except Exception as e:
logger.error(f"Failed to get preference tags for user {user_id}: {e}")
raise
async def get_dimension_portrait(
self,
user_id: str,
include_history: bool = False
) -> DimensionPortrait:
"""Get user's four-dimension personality portrait.
Args:
user_id: Target user ID
include_history: Whether to include historical trends
Returns:
User's dimension portrait
"""
logger.info(f"Getting dimension portrait for user {user_id}")
try:
# Get user summaries
user_summaries = await self.extract_user_summaries(user_id=user_id)
if not user_summaries:
logger.warning(f"No summaries found for user {user_id}")
return self.dimension_analyzer._create_empty_portrait(user_id)
# Analyze dimensions
dimension_portrait = await self.dimension_analyzer.analyze_dimensions(
user_id=user_id,
user_summaries=user_summaries
)
# Include historical trends if requested
if include_history:
# In a full implementation, this would retrieve historical data
# For now, we'll leave historical_trends as None
pass
logger.info(f"Retrieved dimension portrait for user {user_id}")
return dimension_portrait
except Exception as e:
logger.error(f"Failed to get dimension portrait for user {user_id}: {e}")
raise
async def get_interest_area_distribution(
self,
user_id: str,
include_trends: bool = False
) -> InterestAreaDistribution:
"""Get user's interest area distribution across four areas.
Args:
user_id: Target user ID
include_trends: Whether to include trending information
Returns:
User's interest area distribution
"""
logger.info(f"Getting interest area distribution for user {user_id}")
try:
# Get user summaries
user_summaries = await self.extract_user_summaries(user_id=user_id)
if not user_summaries:
logger.warning(f"No summaries found for user {user_id}")
return self.interest_analyzer._create_empty_distribution(user_id)
# Analyze interests
interest_distribution = await self.interest_analyzer.analyze_interests(
user_id=user_id,
user_summaries=user_summaries
)
# Include trends if requested
if include_trends:
# In a full implementation, this would calculate trending directions
# For now, we'll leave trending_direction as None for each category
pass
logger.info(f"Retrieved interest area distribution for user {user_id}")
return interest_distribution
except Exception as e:
logger.error(f"Failed to get interest area distribution for user {user_id}: {e}")
raise
async def get_behavior_habits(
self,
user_id: str,
confidence_level: Optional[str] = None,
frequency_pattern: Optional[str] = None,
time_period: Optional[str] = None
) -> List[BehaviorHabit]:
"""Get user's behavioral habits with filtering.
Args:
user_id: Target user ID
confidence_level: Optional confidence level filter ("high", "medium", "low")
frequency_pattern: Optional frequency pattern filter
time_period: Optional time period filter ("current", "past")
Returns:
List of filtered behavioral habits
"""
logger.info(f"Getting behavior habits for user {user_id}")
try:
# Get user summaries
user_summaries = await self.extract_user_summaries(user_id=user_id)
if not user_summaries:
logger.warning(f"No summaries found for user {user_id}")
return []
# Detect habits
behavior_habits = await self.habit_detector.detect_habits(
user_id=user_id,
user_summaries=user_summaries
)
# Apply filters
filtered_habits = []
for habit in behavior_habits:
# Filter by confidence level
if confidence_level:
try:
target_confidence = ConfidenceLevel(confidence_level.lower())
if habit.confidence_level != target_confidence:
continue
except ValueError:
logger.warning(f"Invalid confidence level: {confidence_level}")
continue
# Filter by frequency pattern
if frequency_pattern:
try:
target_frequency = FrequencyPattern(frequency_pattern.lower())
if habit.frequency_pattern != target_frequency:
continue
except ValueError:
logger.warning(f"Invalid frequency pattern: {frequency_pattern}")
continue
# Filter by time period
if time_period:
if time_period.lower() == "current" and not habit.is_current:
continue
elif time_period.lower() == "past" and habit.is_current:
continue
filtered_habits.append(habit)
# Sort by confidence level and recency
confidence_order = {"high": 3, "medium": 2, "low": 1}
filtered_habits.sort(
key=lambda x: (
confidence_order.get(x.confidence_level.value, 0),
x.last_observed
),
reverse=True
)
logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user {user_id}")
return filtered_habits
except Exception as e:
logger.error(f"Failed to get behavior habits for user {user_id}: {e}")
raise

View File

@@ -68,12 +68,7 @@ edges:
label: 完成
# 变量定义
variables:
- name: user_question
type: string
required: true
description: 用户的问题
default: ""
variables: []
# 执行配置
execution_config:

View File

@@ -0,0 +1,88 @@
"""202601071800
Revision ID: c6d4afa27bf0
Revises: 8372101eda28
Create Date: 2026-01-07 17:59:23.032323
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = 'c6d4afa27bf0'
down_revision: Union[str, None] = '8372101eda28'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('memory_long_term',
sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'),
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_memory_long_term_created_at'), 'memory_long_term', ['created_at'], unique=False)
op.create_index(op.f('ix_memory_long_term_end_user_id'), 'memory_long_term', ['end_user_id'], unique=False)
op.create_index(op.f('ix_memory_long_term_id'), 'memory_long_term', ['id'], unique=False)
op.create_table('memory_short_term',
sa.Column('id', sa.UUID(), nullable=False, comment='记忆ID'),
sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'),
sa.Column('messages', sa.Text(), nullable=False, comment='用户消息内容'),
sa.Column('aimessages', sa.Text(), nullable=True, comment='AI回复消息内容'),
sa.Column('search_switch', sa.String(length=50), nullable=True, comment='搜索开关状态'),
sa.Column('retrieved_content', sa.JSON(), nullable=True, comment='检索到的相关内容,格式为[{}, {}]'),
sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_memory_short_term_created_at'), 'memory_short_term', ['created_at'], unique=False)
op.create_index(op.f('ix_memory_short_term_end_user_id'), 'memory_short_term', ['end_user_id'], unique=False)
op.create_index(op.f('ix_memory_short_term_id'), 'memory_short_term', ['id'], unique=False)
op.create_table('memory_perceptual',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('end_user_id', sa.UUID(), nullable=True),
sa.Column('perceptual_type', sa.Integer(), nullable=False, comment='感知类型'),
sa.Column('storage_service', sa.Integer(), nullable=True, comment='存储服务类型'),
sa.Column('file_path', sa.String(), nullable=False, comment='文件路径'),
sa.Column('file_name', sa.String(), nullable=False, comment='文件名称'),
sa.Column('file_ext', sa.String(), nullable=False, comment='文件后缀名'),
sa.Column('summary', sa.String(), nullable=True, comment='摘要'),
sa.Column('meta_data', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='元信息'),
sa.Column('created_time', sa.DateTime(), nullable=True, comment='创建时间'),
sa.ForeignKeyConstraint(['end_user_id'], ['end_users.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_memory_perceptual_end_user_id'), 'memory_perceptual', ['end_user_id'], unique=False)
op.create_index(op.f('ix_memory_perceptual_perceptual_type'), 'memory_perceptual', ['perceptual_type'], unique=False)
op.alter_column('multi_agent_configs', 'orchestration_mode',
existing_type=sa.VARCHAR(length=20),
comment='协作模式: collaboration协作| supervisor监督',
existing_comment='协作模式: sequential|parallel|conditional|loop',
existing_nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column('multi_agent_configs', 'orchestration_mode',
existing_type=sa.VARCHAR(length=20),
comment='协作模式: sequential|parallel|conditional|loop',
existing_comment='协作模式: collaboration协作| supervisor监督',
existing_nullable=False)
op.drop_index(op.f('ix_memory_perceptual_perceptual_type'), table_name='memory_perceptual')
op.drop_index(op.f('ix_memory_perceptual_end_user_id'), table_name='memory_perceptual')
op.drop_table('memory_perceptual')
op.drop_index(op.f('ix_memory_short_term_id'), table_name='memory_short_term')
op.drop_index(op.f('ix_memory_short_term_end_user_id'), table_name='memory_short_term')
op.drop_index(op.f('ix_memory_short_term_created_at'), table_name='memory_short_term')
op.drop_table('memory_short_term')
op.drop_index(op.f('ix_memory_long_term_id'), table_name='memory_long_term')
op.drop_index(op.f('ix_memory_long_term_end_user_id'), table_name='memory_long_term')
op.drop_index(op.f('ix_memory_long_term_created_at'), table_name='memory_long_term')
op.drop_table('memory_long_term')
# ### end Alembic commands ###

View File

@@ -1137,10 +1137,10 @@ export const en = {
promptEmpty: 'Describe your use case on the left, and the orchestration preview will be displayed here.',
master: 'Supervisor Mode',
master_agent: 'Supervisor Mode',
master_agentDesc: 'Unified scheduling and management by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.',
handoffs: 'Collaboration Mode',
handoffsDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.',
supervisor: 'Supervisor Mode',
supervisorDesc: 'Unified scheduling and management by the main Agent, with sub-Agents executing tasks assigned by the supervisor, suitable for scenarios requiring centralized control.',
collaboration: 'Collaboration Mode',
collaborationDesc: 'Multiple Agents collaborate equally, autonomously coordinating according to task requirements, suitable for complex scenarios requiring flexible interaction.',
masterConfig: 'Supervisor Configuration',
orchestrationMode: 'Task Assignment Strategy',
conditional: 'Intelligent Assignment',
@@ -1150,6 +1150,8 @@ export const en = {
merge: 'Complete Aggregation',
vote: 'Key Information Extraction',
priority: 'Structured Integration',
addTool: 'Add Tool',
tool: 'Tool',
},
userMemory: {
userMemory: 'User Memory',
@@ -1207,6 +1209,7 @@ export const en = {
IMPLICIT_MEMORY: 'Implicit Memory',
EMOTIONAL_MEMORY: 'Emotional Memory',
EPISODIC_MEMORY: 'Episodic Memory',
FORGETTING_MANAGEMENT: 'Forgetting Management',
endUserProfile: 'Core Profile',
editEndUserProfile: 'Edit',
@@ -1839,6 +1842,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
status_code: 'Status Code',
max_attempts: 'Max Retry Attempts',
retry_interval: 'Retry Interval',
errorBranch: 'Error Branch',
},
'jinja-render': {
template: 'Code',

View File

@@ -626,10 +626,10 @@ export const zh = {
promptEmpty: '在左侧描述您的用例,编排预览将在此处显示。',
master: '主管模式',
master_agent: '主管模式',
master_agentDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。',
handoffs: '协作模式',
handoffsDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。',
supervisor: '主管模式',
supervisorDesc: '由主 Agent 统一调度和管理,子 Agent 按照主管分配的任务执行,适合需要集中控制的场景。',
collaboration: '协作模式',
collaborationDesc: '多个 Agent 平等协作,根据任务需求自主协调配合,适合需要灵活互动的复杂场景。',
masterConfig: '主管配置',
orchestrationMode: '任务分配策略',
conditional: '智能分配',
@@ -639,6 +639,8 @@ export const zh = {
merge: '完整汇总',
vote: '关键信息提取',
priority: '结构化整合',
addTool: '添加工具',
tool: '工具',
},
// 角色管理相关翻译
role: {
@@ -1286,6 +1288,7 @@ export const zh = {
IMPLICIT_MEMORY: '隐性记忆',
EMOTIONAL_MEMORY: '情绪记忆',
EPISODIC_MEMORY: '情景记忆',
FORGETTING_MANAGEMENT: '遗忘',
endUserProfile: '核心档案',
editEndUserProfile: '编辑',
@@ -1939,6 +1942,7 @@ export const zh = {
status_code: '状态码',
max_attempts: '最大重试次数',
retry_interval: '重试间隔',
errorBranch: '异常分支',
},
'jinja-render': {
template: '代码',
@@ -2252,5 +2256,12 @@ export const zh = {
orderPayInfo: '支付信息',
create_time: '创建时间',
},
forgetDetail: {
title: '遗忘管理系统帮助AI智能管理记忆生命周期通过自动识别低价值记忆、设置遗忘策略和执行定期清理优化记忆库存储空间提升检索效率。',
overviewTitle: '核心指标概览',
totalMemory: '记忆总量',
MemoryHealth: '记忆健康度',
riskOfForgetting: '遗忘风险',
}
},
}

View File

@@ -19,7 +19,6 @@ import type {
MemoryConfig,
AiPromptModalRef,
Source,
ToolModalRef,
ToolOption
} from './types'
import type { Model } from '@/views/ModelManagement/types'
@@ -33,7 +32,6 @@ import { memoryConfigListUrl } from '@/api/memory'
import CustomSelect from '@/components/CustomSelect'
import aiPrompt from '@/assets/images/application/aiPrompt.png'
import AiPromptModal from './components/AiPromptModal'
import ToolModal from './components/ToolModal'
import ToolList from './components/ToolList'
const DescWrapper: FC<{desc: string, className?: string}> = ({desc, className}) => {
@@ -115,6 +113,7 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
const [variableList, setVariableList] = useState<Variable[]>([])
const [isSave, setIsSave] = useState(false)
const initialized = useRef(false)
const [toolList, setToolList] = useState<ToolOption[]>([])
// 初始化完成标记
useEffect(() => {
@@ -143,6 +142,11 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
if (isSave) return
setIsSave(true)
}, [values])
useEffect(() => {
if (!initialized.current) return
if (isSave) return
setIsSave(true)
}, [toolList])
useEffect(() => {
getModels()
@@ -294,7 +298,11 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
...(item.config || {})
}))
} as KnowledgeConfig : null,
tools: toolList
tools: toolList.map(vo => ({
tool_id: vo.tool_id,
operation: vo.operation,
enabled: vo.enabled
}))
}
console.log('params', rest, params)
@@ -347,18 +355,6 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
form.setFieldValue('system_prompt', value)
}
const toolModalRef = useRef<ToolModalRef>(null)
const [toolList, setToolList] = useState<ToolOption[]>([])
const handleAddTool = () => {
toolModalRef.current?.handleOpen()
}
const updateTools = (tool: ToolOption) => {
const tools = [...toolList, tool]
setToolList(tools)
form.setFieldValue('tools', tools)
}
console.log('toolList', toolList)
return (
<>
{loading && <Spin fullscreen></Spin>}
@@ -469,10 +465,6 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
defaultModel={defaultModel}
refresh={updatePrompt}
/>
<ToolModal
ref={toolModalRef}
refresh={updateTools}
/>
</>
);
});

View File

@@ -42,7 +42,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
const handleSave = (flag = true) => {
if (!data) return Promise.resolve()
if (!values.default_model_config_id) {
if (!values.default_model_config_id && values.orchestration_mode === 'supervisor') {
message.warning(t('common.selectPlaceholder', { title: t('application.model') }))
return Promise.resolve()
}
@@ -140,15 +140,14 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
<Space size={20} direction="vertical" style={{width: '100%'}}>
<Card title={t('application.handoffs')}>
<Form.Item
name={['execution_config', 'routing_mode']}
name="orchestration_mode"
noStyle
>
<RadioGroupCard
options={['master_agent', 'handoffs'].map((type) => ({
options={['supervisor', 'collaboration'].map((type) => ({
value: type,
label: t(`application.${type}`),
labelDesc: t(`application.${type}Desc`),
disabled: type === 'handoffs'
}))}
allowClear={false}
/>
@@ -192,7 +191,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
))}
</Card>
<Card title={t('application.masterConfig')}>
{values?.orchestration_mode !== 'collaboration' && <Card title={t('application.masterConfig')}>
<Form.Item
label={t('application.model')}
required={true}
@@ -218,11 +217,11 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
</Row>
</Form.Item>
<Form.Item
name="orchestration_mode"
name={['execution_config',"sub_agent_execution_mode"]}
label={t('application.orchestrationMode')}
>
<Select
options={['conditional', 'sequential', 'parallel'].map((type) => ({
options={['sequential', 'parallel'].map((type) => ({
value: type,
label: t(`application.${type}`),
}))}
@@ -239,7 +238,7 @@ const Cluster = forwardRef<ClusterRef>((_props, ref) => {
}))}
/>
</Form.Item>
</Card>
</Card>}
</Space>
</Form>
</Col>

View File

@@ -90,11 +90,19 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
switch (item.event) {
case 'start':
currentPromptValueRef.current = ''
if (editorRef.current?.clear) {
editorRef.current.clear();
}
break;
case 'message':
if (content) {
currentPromptValueRef.current += content;
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
if (editorRef.current?.appendText) {
editorRef.current.appendText(content);
editorRef.current.scrollToBottom();
} else {
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
}
}
if (desc) {
setChatList(prev => {
@@ -107,6 +115,8 @@ const AiPromptModal = forwardRef<AiPromptModalRef, AiPromptModalProps>(({
break;
case 'end':
setLoading(false)
// 流结束时同步表单值
form.setFieldsValue({ current_prompt: currentPromptValueRef.current })
break
}
})

View File

@@ -4,7 +4,7 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer';
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary';
import { $getSelection } from 'lexical';
import { $getSelection, $getRoot, $createParagraphNode, $createTextNode, $isParagraphNode, $isTextNode } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import InitialValuePlugin from './plugin/InitialValuePlugin'
import LineBreakPlugin from './plugin/LineBreakPlugin';
@@ -12,6 +12,9 @@ import InsertTextPlugin from './plugin/InsertTextPlugin';
export interface EditorRef {
insertText: (text: string) => void;
appendText: (text: string) => void;
clear: () => void;
scrollToBottom: () => void;
}
interface LexicalEditorProps {
@@ -46,6 +49,41 @@ const EditorContent = forwardRef<EditorRef, LexicalEditorProps>(({
selection.insertText(text);
}
});
},
appendText: (text: string) => {
editor.update(() => {
const root = $getRoot();
const lastChild = root.getLastChild();
if (lastChild && $isParagraphNode(lastChild)) {
const lastTextNode = lastChild.getLastChild();
if (lastTextNode && $isTextNode(lastTextNode)) {
const currentText = lastTextNode.getTextContent();
lastTextNode.setTextContent(currentText + text);
} else {
const textNode = $createTextNode(text);
lastChild.append(textNode);
}
} else {
const paragraph = $createParagraphNode();
const textNode = $createTextNode(text);
paragraph.append(textNode);
root.append(paragraph);
}
});
},
clear: () => {
editor.update(() => {
const root = $getRoot();
root.clear();
const paragraph = $createParagraphNode();
root.append(paragraph);
});
},
scrollToBottom: () => {
const editorElement = editor.getRootElement();
if (editorElement) {
editorElement.scrollTop = editorElement.scrollHeight;
}
}
}), [editor]);

View File

@@ -1,21 +1,37 @@
import { type FC, useEffect } from 'react';
import { type FC, useEffect, useRef } from 'react';
import { $getRoot, $createParagraphNode, $createTextNode } from 'lexical';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
// 设置初始值的插件
const InitialValuePlugin: FC<{ value?: string }> = ({ value }) => {
const [editor] = useLexicalComposerContext();
const lastValueRef = useRef<string | undefined>(undefined);
useEffect(() => {
if (value) {
// 只有当value真正发生变化时才更新
if (lastValueRef.current !== value) {
editor.update(() => {
const root = $getRoot();
const currentText = root.getTextContent();
// 如果当前内容和新值相同,则不更新
if (currentText === (value || '')) {
return;
}
root.clear();
const paragraph = $createParagraphNode();
const textNode = $createTextNode(value);
paragraph.append(textNode);
root.append(paragraph);
if (value) {
const paragraph = $createParagraphNode();
const textNode = $createTextNode(value);
paragraph.append(textNode);
root.append(paragraph);
} else {
// 当value为undefined或空时创建一个空段落
const paragraph = $createParagraphNode();
root.append(paragraph);
}
});
lastValueRef.current = value;
}
}, [editor, value]);

View File

@@ -1,4 +1,4 @@
import { type FC, useState } from 'react';
import { type FC, useState, useEffect } from 'react';
import { LexicalComposer } from '@lexical/react/LexicalComposer';
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin';
import { ContentEditable } from '@lexical/react/LexicalContentEditable';
@@ -23,6 +23,7 @@ interface LexicalEditorProps {
options: Suggestion[];
variant?: 'outlined' | 'borderless';
height?: number;
enableJinja2?: boolean;
}
const theme = {
@@ -33,6 +34,15 @@ const theme = {
},
};
const jinja2Theme = {
...theme,
code: 'jinja2-expression',
text: {
...theme.text,
code: 'jinja2-inline',
},
};
const Editor: FC<LexicalEditorProps> =({
placeholder = "请输入内容...",
value = "",
@@ -40,19 +50,62 @@ const Editor: FC<LexicalEditorProps> =({
options,
variant = 'borderless',
height = 60,
enableJinja2 = false,
}) => {
const [_count, setCount] = useState(0);
useEffect(() => {
if (enableJinja2) {
const styleId = 'jinja2-styles';
let existingStyle = document.getElementById(styleId);
if (!existingStyle) {
const style = document.createElement('style');
style.id = styleId;
style.textContent = `
.jinja2-expression {
background-color: #f6f8fa !important;
border: 1px solid #d1d9e0 !important;
border-radius: 3px !important;
padding: 2px 4px !important;
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
font-size: 13px !important;
color: #0969da !important;
}
.jinja2-inline {
background-color: #f6f8fa !important;
padding: 1px 3px !important;
border-radius: 2px !important;
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
font-size: 13px !important;
color: #0969da !important;
}
.editor-paragraph {
margin: 0;
}
.editor-paragraph:has-text('{') .editor-text,
.editor-paragraph:has-text('[') .editor-text {
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, monospace !important;
}
`;
document.head.appendChild(style);
}
}
}, [enableJinja2]);
const initialConfig = {
namespace: 'AutocompleteEditor',
theme,
nodes: [
theme: enableJinja2 ? jinja2Theme : theme,
nodes: enableJinja2 ? [
// 当启用jinja2时不使用VariableNode使用普通文本
] : [
// HeadingNode,
// QuoteNode,
// ListItemNode,
// ListNode,
// LinkNode,
// CodeNode,
VariableNode
VariableNode,
],
onError: (error: Error) => {
console.error(error);
@@ -96,9 +149,9 @@ const Editor: FC<LexicalEditorProps> =({
/>
<HistoryPlugin />
<CommandPlugin />
<AutocompletePlugin options={options} />
<AutocompletePlugin options={options} enableJinja2={enableJinja2} />
<CharacterCountPlugin setCount={(count) => { setCount(count) }} onChange={onChange} />
<InitialValuePlugin value={value} options={options} />
<InitialValuePlugin value={value} options={options} enableJinja2={enableJinja2} />
</div>
</LexicalComposer>
);

View File

@@ -17,7 +17,7 @@ export interface Suggestion {
disabled?: boolean; // 标记是否禁用
}
const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
const AutocompletePlugin: FC<{ options: Suggestion[], enableJinja2?: boolean }> = ({ options, enableJinja2 = false }) => {
const [editor] = useLexicalComposerContext();
const [showSuggestions, setShowSuggestions] = useState(false);
const [selectedIndex, setSelectedIndex] = useState(0);
@@ -82,7 +82,32 @@ const AutocompletePlugin: FC<{ options: Suggestion[] }> = ({ options }) => {
}, [editor]);
const insertMention = (suggestion: Suggestion) => {
editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
if (enableJinja2) {
// 在jinja2模式下插入{{variable}}格式的文本
editor.update(() => {
const selection = $getSelection();
if ($isRangeSelection(selection)) {
const anchorNode = selection.anchor.getNode();
const anchorOffset = selection.anchor.offset;
const nodeText = anchorNode.getTextContent();
// 移除触发字符'/'
const textBefore = nodeText.substring(0, anchorOffset - 1);
const textAfter = nodeText.substring(anchorOffset);
const newText = textBefore + `{{${suggestion.value}}}` + textAfter;
anchorNode.setTextContent(newText);
// 设置光标位置到插入文本之后
const newOffset = textBefore.length + `{{${suggestion.value}}}`.length;
selection.anchor.offset = newOffset;
selection.focus.offset = newOffset;
}
});
} else {
// 普通模式下使用VariableNode
editor.dispatchCommand(INSERT_VARIABLE_COMMAND, { data: suggestion });
}
setShowSuggestions(false);
};

View File

@@ -8,14 +8,31 @@ import { type Suggestion } from '../plugin/AutocompletePlugin'
interface InitialValuePluginProps {
value: string;
options?: Suggestion[];
enableJinja2?: boolean;
}
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [] }) => {
const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options = [], enableJinja2 = false }) => {
const [editor] = useLexicalComposerContext();
const initializedRef = useRef(false);
const prevValueRef = useRef<string>('');
const isUserInputRef = useRef(false);
useEffect(() => {
if (!initializedRef.current && value) {
// 监听编辑器变化,标记是否为用户输入
const removeListener = editor.registerUpdateListener(({ editorState }) => {
editorState.read(() => {
const root = $getRoot();
const textContent = root.getTextContent();
if (textContent !== prevValueRef.current) {
isUserInputRef.current = true;
}
});
});
return removeListener;
}, [editor]);
useEffect(() => {
if (value !== prevValueRef.current && !isUserInputRef.current) {
editor.update(() => {
const root = $getRoot();
root.clear();
@@ -28,7 +45,11 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
const contextMatch = part.match(/^\{\{context\}\}$/);
const conversationMatch = part.match(/^\{\{conv\.([^}]+)\}\}$/);
// 匹配{{context}}格式
if (enableJinja2) {
paragraph.append($createTextNode(part));
return;
}
if (contextMatch) {
const contextSuggestion = options.find(s => s.isContext && s.label === 'context');
if (contextSuggestion) {
@@ -39,7 +60,6 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
return
}
// 匹配{{conv.xx}}格式
if (conversationMatch) {
const [_, variableName] = conversationMatch;
const conversationSuggestion = options.find(s =>
@@ -53,7 +73,6 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
return
}
// 匹配普通变量{{nodeId.label}}格式
if (match) {
const [_, nodeId, label] = match;
@@ -75,11 +94,12 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
});
root.append(paragraph);
});
initializedRef.current = true;
}, { discrete: true });
}
}, [options]);
prevValueRef.current = value;
isUserInputRef.current = false;
}, [value, options, editor, enableJinja2]);
return null;
};

View File

@@ -0,0 +1,109 @@
import { useEffect } from 'react';
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext';
import { $getRoot, $getSelection, $isRangeSelection, TextNode, $createTextNode } from 'lexical';
const JsonHighlightPlugin = () => {
const [editor] = useLexicalComposerContext();
useEffect(() => {
return editor.registerNodeTransform(TextNode, (textNode: TextNode) => {
const text = textNode.getTextContent();
// Check if text contains JSON-like patterns
if (containsJsonPatterns(text)) {
const parent = textNode.getParent();
if (!parent) return;
// Split text into tokens and create new nodes with appropriate classes
const tokens = tokenizeJson(text);
const newNodes = tokens.map(token => {
const newNode = $createTextNode(token.text);
// Set format based on token type
switch (token.type) {
case 'string':
newNode.setFormat('code');
newNode.setStyle('color: #032f62');
break;
case 'number':
newNode.setFormat('code');
newNode.setStyle('color: #005cc5');
break;
case 'boolean':
newNode.setFormat('code');
newNode.setStyle('color: #d73a49');
break;
case 'null':
newNode.setFormat('code');
newNode.setStyle('color: #6f42c1');
break;
case 'key':
newNode.setFormat('code');
newNode.setStyle('color: #22863a; font-weight: bold');
break;
case 'punctuation':
newNode.setFormat('code');
newNode.setStyle('color: #24292e');
break;
}
return newNode;
});
// Replace the original text node with the new highlighted nodes
if (newNodes.length > 1) {
textNode.replace(newNodes[0]);
for (let i = 1; i < newNodes.length; i++) {
newNodes[i - 1].insertAfter(newNodes[i]);
}
}
}
});
}, [editor]);
return null;
};
function containsJsonPatterns(text: string): boolean {
// Check for JSON-like patterns
return /[{}\[\]:,]/.test(text) ||
/"[^"]*"/.test(text) ||
/\b\d+(\.\d+)?\b/.test(text) ||
/\b(true|false|null)\b/.test(text);
}
function tokenizeJson(text: string): Array<{text: string, type: string}> {
const tokens: Array<{text: string, type: string}> = [];
const regex = /("[^"]*")|([{}\[\]:,])|(\b\d+(?:\.\d+)?\b)|(\b(?:true|false|null)\b)|(\s+)|([^\s{}\[\]:,"]+)/g;
let match;
while ((match = regex.exec(text)) !== null) {
const [fullMatch, string, punctuation, number, boolean, whitespace, other] = match;
if (string) {
// Check if it's a key (followed by colon)
const afterMatch = text.slice(match.index + fullMatch.length).trim();
if (afterMatch.startsWith(':')) {
tokens.push({ text: fullMatch, type: 'key' });
} else {
tokens.push({ text: fullMatch, type: 'string' });
}
} else if (punctuation) {
tokens.push({ text: fullMatch, type: 'punctuation' });
} else if (number) {
tokens.push({ text: fullMatch, type: 'number' });
} else if (boolean) {
if (fullMatch === 'null') {
tokens.push({ text: fullMatch, type: 'null' });
} else {
tokens.push({ text: fullMatch, type: 'boolean' });
}
} else if (whitespace || other) {
tokens.push({ text: fullMatch, type: 'text' });
}
}
return tokens;
}
export default JsonHighlightPlugin;

View File

@@ -28,12 +28,18 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
const handleOpen = (data?: HttpRequestConfigForm['auth']) => {
if (data) {
form.setFieldsValue({
const initialValues = {
auth: !data.auth_type || data.auth_type === 'none' ? 'none' : 'api_key',
auth_type: !data.auth_type || data.auth_type === 'none' ? undefined : data.auth_type,
header: data.header,
api_key: data.api_key
})
}
form.setFieldValue('auth', initialValues.auth)
if (initialValues.auth !== 'none') {
setTimeout(() => {
form.setFieldsValue(initialValues)
}, 1)
}
}
setVisible(true);
};
@@ -91,6 +97,9 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
<FormItem
name="auth"
label={t('workflow.config.http-request.authType')}
rules={[
{ required: true, message: t('common.pleaseSelect') }
]}
>
<Select
options={[
@@ -103,6 +112,9 @@ const AuthConfigModal = forwardRef<AuthConfigModalRef, AuthConfigModalProps>(({
<FormItem
name="auth_type"
label={t('workflow.config.http-request.authType')}
rules={[
{ required: true, message: t('common.pleaseSelect') }
]}
>
<Select
options={[

View File

@@ -1,4 +1,4 @@
import { useMemo, useCallback } from 'react';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next'
import { Button, Select, Table, Form, type TableProps } from 'antd';
import { PlusOutlined, DeleteOutlined } from '@ant-design/icons';
@@ -6,46 +6,8 @@ import type { Suggestion } from '../../Editor/plugin/AutocompletePlugin';
import Empty from '@/components/Empty';
import VariableSelect from '../VariableSelect';
interface EditableCellProps extends React.HTMLAttributes<HTMLElement> {
name?: string | string[];
inputType?: 'select' | 'variableSelect';
options?: { value: string, label: string }[] | Suggestion[];
}
const EditableCell: React.FC<React.PropsWithChildren<EditableCellProps>> = ({
name,
inputType,
options,
children,
...restProps
}) => {
const { t } = useTranslation();
if (!inputType) return <td {...restProps}>{children}</td>;
return (
<td {...restProps}>
<Form.Item name={name} style={{ margin: 0 }}>
{inputType === 'select' ? (
<Select
placeholder={t('common.pleaseSelect')}
size="small"
options={options as { value: string, label: string }[]}
/>
) : (
<VariableSelect
placeholder={t('common.pleaseSelect')}
size="small"
options={(options as Suggestion[]) || []}
/>
)}
</Form.Item>
</td>
);
};
export interface TableRow {
key: string;
key?: string;
name?: string;
value?: string;
type?: string;
@@ -56,100 +18,158 @@ interface EditableTableProps {
title?: string;
options?: Suggestion[];
typeOptions?: { value: string, label: string }[]
filterBooleanType?: boolean;
}
const EditableTable: React.FC<EditableTableProps> = ({
parentName,
title,
options = [],
typeOptions = []
typeOptions = [],
filterBooleanType = false
}) => {
const { t } = useTranslation();
const form = Form.useFormInstance();
const values = Form.useWatch(typeof parentName === 'string' ? [parentName] : parentName, form);
const createNewRow = (): TableRow => ({
key: Date.now().toString(),
name: undefined,
value: undefined,
...(typeOptions.length > 0 && { type: typeOptions[0].value })
});
const handleAdd = useCallback(() => {
form.setFieldValue(parentName, [...(values ?? []), createNewRow()]);
}, [form, parentName, values, typeOptions]);
const handleDelete = useCallback((index: number) => {
const currentValues = form.getFieldValue(parentName) || [];
form.setFieldValue(parentName, currentValues.filter((_: TableRow, i: number) => i !== index));
}, [form, parentName]);
const createColumn = (dataIndex: string, inputType: 'select' | 'variableSelect', width: string, columnOptions: any[]) => ({
title: t(`workflow.config.${dataIndex}`),
dataIndex,
width,
onCell: (_: TableRow, index?: number) => ({
name: typeof parentName === 'string' ? [parentName, index ?? 0, dataIndex] : [...parentName, index ?? 0, dataIndex],
inputType,
options: columnOptions
} as any)
});
const columns: TableProps<TableRow>['columns'] = useMemo(() => {
const getColumns = (remove: (index: number) => void): TableProps<TableRow>['columns'] => {
const hasType = typeOptions.length > 0;
const baseWidth = hasType ? '35%' : '45%';
return [
createColumn('name', 'variableSelect', baseWidth, options),
...(hasType ? [createColumn('type', 'select', '20%', typeOptions)] : []),
createColumn('value', 'variableSelect', baseWidth, options),
{
title: t('workflow.config.name'),
dataIndex: 'name',
width: baseWidth,
render: (_: any, __: TableRow, index: number) => (
<Form.Item name={[index, 'name']} noStyle>
<VariableSelect
placeholder={t('common.pleaseSelect')}
size="small"
options={options}
filterBooleanType={filterBooleanType}
popupMatchSelectWidth={false}
/>
</Form.Item>
)
},
...(hasType ? [{
title: t('workflow.config.type'),
dataIndex: 'type',
width: '20%',
render: (_: any, __: TableRow, index: number) => (
<Form.Item shouldUpdate noStyle>
{(form) => (
<Form.Item name={[index, 'type']} noStyle>
<Select
placeholder={t('common.pleaseSelect')}
size="small"
options={typeOptions}
popupMatchSelectWidth={false}
onChange={() => {
form.setFieldValue([...Array.isArray(parentName) ? parentName : [parentName], index, 'value'], undefined);
}}
/>
</Form.Item>
)}
</Form.Item>
)
}] : []),
{
title: t('workflow.config.value'),
dataIndex: 'value',
width: baseWidth,
render: (_: any, __: TableRow, index: number) => (
<Form.Item
shouldUpdate={(prevValues, currentValues) => {
const prevType = prevValues?.[Array.isArray(parentName) ? parentName.join('.') : parentName]?.[index]?.type;
const currentType = currentValues?.[Array.isArray(parentName) ? parentName.join('.') : parentName]?.[index]?.type;
return prevType !== currentType;
}}
noStyle
>
{(form) => {
const currentType = form.getFieldValue([...Array.isArray(parentName) ? parentName : [parentName], index, 'type']);
const filteredOptions = currentType === 'file'
? options.filter(option => option.dataType === 'file')
: options;
return (
<Form.Item name={[index, 'value']} noStyle>
<VariableSelect
placeholder={t('common.pleaseSelect')}
size="small"
options={filteredOptions}
filterBooleanType={filterBooleanType}
popupMatchSelectWidth={false}
/>
</Form.Item>
);
}}
</Form.Item>
)
},
{
title: '',
dataIndex: 'actions',
width: '10%',
render: (_: any, __: TableRow, index: number) => (
<Button type="text" icon={<DeleteOutlined />} onClick={() => handleDelete(index)} />
<Button type="text" icon={<DeleteOutlined />} onClick={() => remove(index)} />
)
}
];
}, [typeOptions, options, t, parentName, handleDelete]);
const AddButton = ({ block = false }: { block?: boolean }) => (
<Button
type={block ? "dashed" : "text"}
icon={<PlusOutlined />}
onClick={handleAdd}
size="small"
block={block}
className={block ? "rb:mt-1" : ""}
>
{block && `+${t('common.add')}`}
</Button>
);
};
return (
<div className="rb:mb-4">
{title && (
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
<div className="rb:font-medium">{title}</div>
<AddButton />
</div>
)}
<Form.Item name={parentName}>
<Table<TableRow>
components={{ body: { cell: EditableCell } }}
bordered
dataSource={values}
columns={columns}
pagination={false}
size="small"
locale={{ emptyText: <Empty size={88} /> }}
scroll={{ x: 'max-content' }}
/>
</Form.Item>
{!title && <AddButton block />}
<Form.List name={parentName}>
{(fields, { add, remove }) => {
const AddButton = ({ block = false }: { block?: boolean }) => (
<Button
type={block ? "dashed" : "text"}
icon={<PlusOutlined />}
onClick={() => add(createNewRow())}
size="small"
block={block}
className={block ? "rb:mt-1" : ""}
>
{block && `+${t('common.add')}`}
</Button>
);
return (
<>
{title && (
<div className="rb:flex rb:items-center rb:mb-2 rb:justify-between">
<div className="rb:font-medium">{title}</div>
<AddButton />
</div>
)}
<Table<TableRow>
bordered
dataSource={fields.map((field) => ({
key: String(field.key),
name: undefined,
value: undefined,
type: undefined
}))}
columns={getColumns(remove)}
pagination={false}
size="small"
locale={{ emptyText: <Empty size={88} /> }}
scroll={{ x: 'max-content' }}
/>
{!title && <AddButton block />}
</>
);
}}
</Form.List>
</div>
);
};

View File

@@ -9,8 +9,10 @@ import VariableSelect from "../VariableSelect";
import MessageEditor from '../MessageEditor'
import EditableTable from './EditableTable'
const HttpRequest: FC<{ options: Suggestion[]; }> = ({
const HttpRequest: FC<{ options: Suggestion[]; selectedNode?: any; graphRef?: any; }> = ({
options,
selectedNode,
graphRef
}) => {
const { t } = useTranslation()
const form = Form.useFormInstance();
@@ -22,18 +24,45 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
}
const handleRefresh = (auth: HttpRequestConfigForm['auth']) => {
console.log('handleRefresh', auth)
form.setFieldsValue({ auth: {...auth} })
form.setFieldsValue({ auth })
}
const handleChangeBodyContentType = (contentType: string) => {
const currentValues = form.getFieldsValue()
const handleChangeBodyContentType = () => {
form.setFieldValue(['body', 'data'], undefined)
}
const handleChangeErrorHandleMethod = (method: string) => {
form.setFieldsValue({
body: {
...currentValues?.body,
content_type: contentType,
data: undefined
error_handle: {
method,
body: undefined,
status_code: undefined,
headers: undefined
}
})
// 更新节点连接桩
console.log('handleChangeErrorHandleMethod', selectedNode, graphRef?.current)
if (selectedNode && graphRef?.current) {
const existingPorts = selectedNode.getPorts();
const errorPort = existingPorts.find((port: any) => port.id === 'ERROR');
if (method === 'branch' && !errorPort) {
// 添加异常节点连接桩
selectedNode.addPort({
id: 'ERROR',
group: 'right',
attrs: { text: { text: t('workflow.config.http-request.errorBranch'), fontSize: 12, fill: '#5B6167' }}
});
} else if (method !== 'branch' && errorPort) {
// 移除异常节点连接桩和相关连线
const edges = graphRef.current.getEdges().filter((edge: any) =>
edge.getSourceCellId() === selectedNode.id && edge.getSourcePortId() === 'ERROR'
);
edges.forEach((edge: any) => graphRef.current.removeCell(edge));
selectedNode.removePort('ERROR');
}
}
}
console.log('HttpRequest', values)
@@ -73,6 +102,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
parentName="headers"
title="HEADERS"
options={options}
filterBooleanType={true}
/>
</Form.Item>
@@ -81,6 +111,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
parentName="params"
title="PARAMS"
options={options}
filterBooleanType={true}
/>
</Form.Item>
@@ -104,6 +135,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
<EditableTable
parentName={['body', 'data']}
options={options}
filterBooleanType={true}
typeOptions={[
{ label: 'text', value: 'text' },
{ label: 'file', value: 'file' }
@@ -116,12 +148,15 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
<EditableTable
parentName={['body', 'data']}
options={options}
filterBooleanType={true}
/>
</Form.Item>
}
{values?.body?.content_type === 'json' &&
<Form.Item name={['body', 'data']}>
<MessageEditor
key="json"
parentName={['body', 'data']}
options={options}
isArray={false}
title="JSON"
@@ -131,6 +166,8 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
{values?.body?.content_type === 'raw' &&
<Form.Item name={['body', 'data']}>
<MessageEditor
key="raw"
parentName={['body', 'data']}
options={options}
isArray={false}
title="RAW TEXT"
@@ -141,6 +178,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
<Form.Item name={['body', 'data']}>
<VariableSelect
options={options}
filterBooleanType={true}
/>
</Form.Item>
}
@@ -185,7 +223,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
</Form.Item>
<Form.Item
name={['retry', 'retry_interval']}
label={t('workflow.config.http-request.retry_interval')}
label={<>{t('workflow.config.http-request.retry_interval')} <span className="rb:text-[#5B6167]">(ms)</span></>}
>
<InputNumber placeholder={t('common.pleaseEnter')} className="rb:w-full!" />
</Form.Item>
@@ -196,6 +234,7 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
<Form.Item layout="horizontal" name={['error_handle', 'method']} label={t('workflow.config.http-request.error_handle')}>
<Select
placeholder={t('common.pleaseSelect')}
onChange={handleChangeErrorHandleMethod}
options={[
{ value: 'none', label: t('workflow.config.http-request.none') },
{ value: 'default', label: t('workflow.config.http-request.default') },
@@ -207,32 +246,19 @@ const HttpRequest: FC<{ options: Suggestion[]; }> = ({
<>
<Form.Item
name={['error_handle', 'body']}
label="body"
label={<>body <span className="rb:text-[#5B6167] rb:ml-1">string</span></>}
>
<Input placeholder={t('common.pleaseEnter')} />
</Form.Item>
<Form.Item
name={['error_handle', 'status_code']}
label="status_code"
label={<>status_code <span className="rb:text-[#5B6167] rb:ml-1">number</span></>}
>
<InputNumber placeholder={t('common.pleaseEnter')} className="rb:w-full!" />
</Form.Item>
<Form.Item
name={['error_handle', 'headers']}
label="headers"
rules={[
{
validator: (_, value) => {
if (!value) return Promise.resolve();
try {
JSON.parse(value);
return Promise.resolve();
} catch {
return Promise.reject(new Error('Please enter valid JSON format'));
}
}
}
]}
label={<>headers <span className="rb:text-[#5B6167] rb:ml-1">object</span></>}
>
<Input.TextArea placeholder={t('common.pleaseEnter')} />
</Form.Item>

View File

@@ -17,14 +17,14 @@ const MappingList: React.FC<MappingListProps> = ({ name, options }) => {
{(fields, { add, remove }) => (
<>
{fields.map(({ key, name, ...restField }) => (
<Row gutter={12} className="rb:mb-2">
<Row key={key} gutter={12} className="rb:mb-2">
<Col span={10}>
<Form.Item
{...restField}
name={[name, 'name']}
noStyle
>
<Input placeholder={t('common.pleaseEnter')} />
<Input placeholder={t('common.pleaseEnter')} data-field-type="mapping-name" />
</Form.Item>
</Col>
<Col span={12}>

View File

@@ -5,14 +5,15 @@ import { MinusCircleOutlined } from '@ant-design/icons';
import Editor from '../Editor'
import type { Suggestion } from '../Editor/plugin/AutocompletePlugin'
interface TextareaProps {
interface MessageEditor {
options: Suggestion[];
title?: string
isArray?: boolean;
parentName?: string;
parentName?: string | string[];
label?: string;
placeholder?: string;
value?: string;
enableJinja2?: boolean;
onChange?: (value?: string) => void;
}
const roleOptions = [
@@ -20,12 +21,13 @@ const roleOptions = [
{ label: 'USER', value: 'USER' },
{ label: 'ASSISTANT', value: 'ASSISTANT' },
]
const MessageEditor: FC<TextareaProps> = ({
const MessageEditor: FC<MessageEditor> = ({
title,
isArray = true,
parentName = 'messages',
placeholder,
options,
enableJinja2 = false,
}) => {
const { t } = useTranslation()
const form = Form.useFormInstance();
@@ -33,10 +35,17 @@ const MessageEditor: FC<TextareaProps> = ({
// 检查是否已经使用了context变量将已使用的context设置为disabled
const processedOptions = useMemo(() => {
if (!isArray || !values?.[parentName]) return options;
if (!isArray) return options;
// 获取表单中对应字段的值
const fieldValue = Array.isArray(parentName)
? parentName.reduce((obj, key) => obj?.[key], values)
: values?.[parentName];
if (!fieldValue) return options;
// 获取所有消息内容
const allContents = values[parentName]
const allContents = fieldValue
.map((msg: any) => msg?.content || '')
.join(' ');
@@ -50,7 +59,11 @@ const MessageEditor: FC<TextareaProps> = ({
}, [options, values, parentName, isArray]);
const handleAdd = (add: FormListOperation['add']) => {
const list = values?.[parentName] || [];
const fieldValue = Array.isArray(parentName)
? parentName.reduce((obj, key) => obj?.[key], values)
: values?.[parentName];
const list = fieldValue || [];
const lastRole = list.length > 0 ? list[list.length - 1]?.role : 'ASSISTANT';
add({
@@ -61,14 +74,14 @@ const MessageEditor: FC<TextareaProps> = ({
if (!isArray) {
return (
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
<Space size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white" data-editor-type={parentName === 'template' ? 'template' : undefined}>
<Row>
<Col span={12}>
{title ?? t('workflow.answerDesc')}
</Col>
</Row>
<Form.Item name={parentName} noStyle>
<Editor placeholder={placeholder} options={processedOptions} />
<Editor enableJinja2={enableJinja2} placeholder={placeholder} options={processedOptions} />
</Form.Item>
</Space>
);
@@ -79,7 +92,11 @@ const MessageEditor: FC<TextareaProps> = ({
{(fields, { add, remove }) => (
<Space size={12} direction="vertical" className="rb:w-full">
{fields.map(({ key, name, ...restField }) => {
const currentRole = (values?.[parentName]?.[name]?.role || 'USER').toUpperCase();
const fieldValue = Array.isArray(parentName)
? parentName.reduce((obj, key) => obj?.[key], values)
: values?.[parentName];
const currentRole = (fieldValue?.[name]?.role || 'USER').toUpperCase();
return (
<Space key={key} size={12} direction="vertical" className="rb:w-full rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-2 rb:py-1.5 rb:bg-white">
@@ -105,7 +122,7 @@ const MessageEditor: FC<TextareaProps> = ({
)}
</Row>
<Form.Item {...restField} name={[name, 'content']} noStyle>
<Editor placeholder={placeholder} options={processedOptions} />
<Editor enableJinja2={enableJinja2} placeholder={placeholder} options={processedOptions} />
</Form.Item>
</Space>
);

View File

@@ -9,6 +9,7 @@ interface VariableSelectProps extends SelectProps {
value?: string;
onChange?: (value: string) => void;
allowClear?: boolean;
filterBooleanType?: boolean;
}
const VariableSelect: FC<VariableSelectProps> = ({
@@ -18,6 +19,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
allowClear = true,
onChange,
size,
filterBooleanType = false,
...resetPorps
}) => {
@@ -26,7 +28,7 @@ const VariableSelect: FC<VariableSelectProps> = ({
}
const labelRender: LabelRender = (props) => {
const { value } = props
const filterOption = options.find(vo => `{{${vo.value}}}` === value)
const filterOption = filteredOptions.find(vo => `{{${vo.value}}}` === value)
if (filterOption) {
return (
@@ -54,7 +56,11 @@ const VariableSelect: FC<VariableSelectProps> = ({
}
return null
}
const groupedSuggestions = options.reduce((groups: Record<string, any[]>, suggestion) => {
const filteredOptions = filterBooleanType
? options.filter(option => option.dataType !== 'boolean')
: options;
const groupedSuggestions = filteredOptions.reduce((groups: Record<string, any[]>, suggestion) => {
const { nodeData } = suggestion
const nodeId = nodeData.id as string;
if (!groups[nodeId]) {
@@ -64,12 +70,10 @@ const VariableSelect: FC<VariableSelectProps> = ({
return groups;
}, {});
const groupedOptions = Object.entries(groupedSuggestions).map(([nodeId, suggestions]) => ({
const groupedOptions = Object.entries(groupedSuggestions).map(([_nodeId, suggestions]) => ({
label: suggestions[0].nodeData.name,
options: suggestions.map(s => ({ label: s.label, value: `{{${s.value}}}` }))
}));
console.log('groupedOptions', groupedOptions)
return (
<Select

View File

@@ -45,13 +45,134 @@ const Properties: FC<PropertiesProps> = ({
const values = Form.useWatch([], form);
const variableModalRef = useRef<VariableEditModalRef>(null)
const [editIndex, setEditIndex] = useState<number | null>(null)
const prevMappingNamesRef = useRef<string[]>([])
const prevTemplateVarsRef = useRef<string[]>([])
const syncTimeoutRef = useRef<NodeJS.Timeout | null>(null)
const isSyncingRef = useRef(false)
const lastSyncSourceRef = useRef<'mapping' | 'template' | null>(null)
useEffect(() => {
if (selectedNode?.getData()?.id) {
form.resetFields()
prevMappingNamesRef.current = []
prevTemplateVarsRef.current = []
lastSyncSourceRef.current = null
}
}, [selectedNode?.getData()?.id])
// Sync template when mapping names change
useEffect(() => {
if (isSyncingRef.current || lastSyncSourceRef.current === 'mapping' || selectedNode?.data?.type !== 'jinja-render' || !values?.mapping || !values?.template) return
const currentMappingNames = values.mapping.map((item: any) => item.name).filter(Boolean)
const prevNames = prevMappingNamesRef.current
if (prevNames.length === 0) {
prevMappingNamesRef.current = currentMappingNames
return
}
if (JSON.stringify(prevNames) === JSON.stringify(currentMappingNames)) return
if (syncTimeoutRef.current) clearTimeout(syncTimeoutRef.current)
const activeElement = document.activeElement as HTMLElement
syncTimeoutRef.current = setTimeout(() => {
let updatedTemplate = String(form.getFieldValue('template') || '')
prevNames.forEach((oldName, index) => {
const newName = currentMappingNames[index]
if (newName && oldName !== newName) {
updatedTemplate = updatedTemplate.replace(new RegExp(`{{\\s*${oldName}\\s*}}`, 'g'), `{{${newName}}}`)
}
})
if (updatedTemplate !== form.getFieldValue('template')) {
isSyncingRef.current = true
lastSyncSourceRef.current = 'mapping'
const newTemplateVars = (updatedTemplate.match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, ''))
prevTemplateVarsRef.current = newTemplateVars
prevMappingNamesRef.current = currentMappingNames
form.setFieldValue('template', updatedTemplate)
requestAnimationFrame(() => {
activeElement?.focus?.()
setTimeout(() => {
isSyncingRef.current = false
lastSyncSourceRef.current = null
}, 50)
})
} else {
prevMappingNamesRef.current = currentMappingNames
}
}, 0)
}, [values?.mapping, selectedNode?.data?.type, form])
// Sync mapping when template variables change
useEffect(() => {
if (isSyncingRef.current || lastSyncSourceRef.current === 'template' || selectedNode?.data?.type !== 'jinja-render' || !values?.template || !values?.mapping) return
const templateVars = (String(values.template).match(/{{\s*([\w.]+)\s*}}/g) || []).map(m => m.replace(/{{\s*|\s*}}/g, ''))
if (JSON.stringify(prevTemplateVarsRef.current) === JSON.stringify(templateVars)) return
const isTemplateEditor = document.activeElement?.closest('[data-editor-type="template"]')
if (!isTemplateEditor) {
prevTemplateVarsRef.current = templateVars
return
}
const updatedMapping = [...values.mapping]
const existingNames = updatedMapping.map(item => item.name)
let updatedTemplate = String(values.template)
if (prevTemplateVarsRef.current.length > 0) {
prevTemplateVarsRef.current.forEach((oldVar, index) => {
const newVar = templateVars[index]
if (newVar && oldVar !== newVar && updatedMapping[index]) {
updatedMapping[index] = { ...updatedMapping[index], name: newVar }
}
})
}
templateVars.forEach(varName => {
const existingMapping = updatedMapping.find(item => item.value === `{{${varName}}}`)
const regex = new RegExp(`{{\\s*${varName.replace(/\./g, '\\.')}\\s*}}`, 'g')
if (existingMapping) {
updatedTemplate = updatedTemplate.replace(regex, `{{${existingMapping.name}}}`)
} else if (!existingNames.includes(varName)) {
const mappingName = varName.includes('.') ? varName.split('.').pop() || varName : varName
updatedMapping.push({ name: mappingName, value: `{{${varName}}}` })
updatedTemplate = updatedTemplate.replace(regex, `{{${mappingName}}}`)
}
})
const seenNames = new Set<string>()
const finalMapping = updatedMapping.filter(item => {
const isUsed = templateVars.some(v => item.name === v || item.value === `{{${v}}}`)
if (!isUsed || seenNames.has(item.name)) return false
seenNames.add(item.name)
return true
})
isSyncingRef.current = true
lastSyncSourceRef.current = 'template'
prevMappingNamesRef.current = finalMapping.map((item: any) => item.name).filter(Boolean)
prevTemplateVarsRef.current = templateVars
if (JSON.stringify(finalMapping) !== JSON.stringify(values.mapping)) {
form.setFieldValue('mapping', finalMapping)
}
if (updatedTemplate !== String(values.template)) {
form.setFieldValue('template', updatedTemplate)
}
setTimeout(() => {
isSyncingRef.current = false
lastSyncSourceRef.current = null
}, 50)
}, [values?.template, selectedNode?.data?.type, form])
useEffect(() => {
if (selectedNode && form) {
const { type = 'default', name = '', config } = selectedNode.getData() || {}
@@ -96,6 +217,8 @@ const Properties: FC<PropertiesProps> = ({
}))
}
Object.keys(values).forEach(key => {
if (selectedNode.data?.config?.[key]) {
// Create a deep copy to avoid reference sharing between nodes
@@ -114,7 +237,7 @@ const Properties: FC<PropertiesProps> = ({
...allRest,
})
}
}, [values, selectedNode])
}, [values, selectedNode, form])
const handleAddVariable = () => {
variableModalRef.current?.handleOpen()
@@ -190,11 +313,88 @@ const Properties: FC<PropertiesProps> = ({
.map(node => node.id);
};
// Find parent loop/iteration node if current node is a child
const getParentLoopNode = (nodeId: string): Node | null => {
const node = nodes.find(n => n.id === nodeId);
if (!node) return null;
const nodeData = node.getData();
const cycle = nodeData?.cycle;
if (cycle) {
const parentNode = nodes.find(n => n.getData().id === cycle);
if (parentNode) {
const parentData = parentNode.getData();
if (parentData?.type === 'loop' || parentData?.type === 'iteration') {
return parentNode;
}
}
}
return null;
};
const allPreviousNodeIds = getAllPreviousNodes(selectedNode.id);
const childNodeIds = getChildNodes(selectedNode.id);
const parentLoopNode = getParentLoopNode(selectedNode.id);
console.log('childNodeIds', selectedNode, childNodeIds)
const allRelevantNodeIds = [...allPreviousNodeIds, ...childNodeIds];
// Add parent loop/iteration node variables if current node is a child
if (parentLoopNode) {
const parentData = parentLoopNode.getData();
if (parentData.type === 'loop') {
const cycleVars = parentData.cycle_vars || [];
cycleVars.forEach((cycleVar: any) => {
const key = `${parentLoopNode.getData().id}_cycle_${cycleVar.name}`;
if (!addedKeys.has(key)) {
addedKeys.add(key);
variableList.push({
key,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'String',
value: `${parentLoopNode.getData().id}.${cycleVar.name}`,
nodeData: parentData,
});
}
});
} else if (parentData.type === 'iteration') {
// Add item and index variables for iteration parent
const itemKey = `${parentLoopNode.getData().id}_item`;
const indexKey = `${parentLoopNode.getData().id}_index`;
if (!addedKeys.has(itemKey)) {
addedKeys.add(itemKey);
variableList.push({
key: itemKey,
label: 'item',
type: 'variable',
dataType: 'Object',
value: `${parentLoopNode.getData().id}.item`,
nodeData: parentData,
});
}
if (!addedKeys.has(indexKey)) {
addedKeys.add(indexKey);
variableList.push({
key: indexKey,
label: 'index',
type: 'variable',
dataType: 'Number',
value: `${parentLoopNode.getData().id}.index`,
nodeData: parentData,
});
}
}
// Add variables from nodes preceding the parent loop/iteration node
const parentPreviousNodeIds = getAllPreviousNodes(parentLoopNode.id);
allRelevantNodeIds.push(...parentPreviousNodeIds);
}
allRelevantNodeIds.forEach(nodeId => {
const node = nodes.find(n => n.id === nodeId);
if (!node) return;
@@ -363,6 +563,87 @@ const Properties: FC<PropertiesProps> = ({
});
}
break
case 'question-classifier':
const classNameKey = `${nodeId}_class_name`;
const outputKey = `${nodeId}_output`;
if (!addedKeys.has(classNameKey)) {
addedKeys.add(classNameKey);
variableList.push({
key: classNameKey,
label: 'class_name',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.class_name`,
nodeData: nodeData,
});
}
if (!addedKeys.has(outputKey)) {
addedKeys.add(outputKey);
variableList.push({
key: outputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${node.getData().id}.output`,
nodeData: nodeData,
});
}
break
case 'iteration':
const iterationOutputKey = `${nodeId}_output`;
if (!addedKeys.has(iterationOutputKey)) {
addedKeys.add(iterationOutputKey);
// Get the data type from the output configuration, default to string
const outputConfig = nodeData.output;
let outputDataType = 'string';
if (outputConfig) {
// Find the selected variable from variableList to get its type
const selectedVariable = variableList.find(v => v.value === outputConfig);
if (selectedVariable) {
outputDataType = selectedVariable.dataType;
}
}
variableList.push({
key: iterationOutputKey,
label: 'output',
type: 'variable',
dataType: outputDataType,
value: `${node.getData().id}.output`,
nodeData: nodeData,
});
}
break
case 'loop':
const cycleVars = nodeData.cycle_vars || [];
cycleVars.forEach((cycleVar: any) => {
const cycleVarKey = `${nodeId}_cycle_${cycleVar.name}`;
if (!addedKeys.has(cycleVarKey)) {
addedKeys.add(cycleVarKey);
variableList.push({
key: cycleVarKey,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'string',
value: `${node.getData().id}.${cycleVar.name}`,
nodeData: nodeData,
});
}
});
break
case 'tool':
const toolDataKey = `${nodeId}_data`;
if (!addedKeys.has(toolDataKey)) {
addedKeys.add(toolDataKey);
variableList.push({
key: toolDataKey,
label: 'data',
type: 'variable',
dataType: 'object',
value: `${node.getData().id}.data`,
nodeData: nodeData,
});
}
break
}
});
@@ -388,6 +669,14 @@ const Properties: FC<PropertiesProps> = ({
return variableList;
}, [selectedNode, graphRef]);
// Filter out boolean type variables for loop and llm nodes
const getFilteredVariableList = (nodeType?: string) => {
if (nodeType === 'loop' || nodeType === 'llm') {
return variableList.filter(variable => variable.dataType !== 'boolean');
}
return variableList;
};
console.log('values', values)
console.log('variableList', variableList, selectedNode?.data)
@@ -412,6 +701,8 @@ const Properties: FC<PropertiesProps> = ({
{selectedNode?.data?.type === 'http-request'
? <HttpRequest
options={variableList}
selectedNode={selectedNode}
graphRef={graphRef}
/>
: selectedNode?.data?.type === 'tool'
? <ToolConfig options={variableList} />
@@ -469,7 +760,7 @@ const Properties: FC<PropertiesProps> = ({
if (selectedNode?.data?.type === 'llm' && key === 'messages' && config.type === 'define') {
// 为llm节点且isArray=true时添加context变量支持
let contextVariableList = [...variableList];
let contextVariableList = [...getFilteredVariableList('llm')];
const isArrayMode = config.isArray !== false; // 默认为true
if (isArrayMode) {
@@ -491,14 +782,14 @@ const Properties: FC<PropertiesProps> = ({
return (
<Form.Item key={key} name={key}>
<MessageEditor options={contextVariableList} parentName={key} />
<MessageEditor key={key} options={contextVariableList} parentName={key} />
</Form.Item>
)
}
if (selectedNode?.data?.type === 'end' && key === 'output') {
return (
<Form.Item key={key} name={key}>
<MessageEditor isArray={false} parentName={key} options={variableList} />
<MessageEditor key={key} isArray={false} parentName={key} options={variableList} />
</Form.Item>
)
}
@@ -525,7 +816,8 @@ const Properties: FC<PropertiesProps> = ({
title={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
isArray={!!config.isArray}
parentName={key}
options={variableList}
enableJinja2={config.enableJinja2 as boolean}
options={getFilteredVariableList(selectedNode?.data?.type)}
/>
</Form.Item>
)
@@ -546,7 +838,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<GroupVariableList
name={key}
options={variableList}
options={getFilteredVariableList(selectedNode?.data?.type)}
isCanAdd={!!(values as any)?.group}
/>
</Form.Item>
@@ -558,7 +850,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<CaseList
name={key}
options={variableList}
options={getFilteredVariableList(selectedNode?.data?.type)}
selectedNode={selectedNode}
graphRef={graphRef}
/>
@@ -571,7 +863,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}
label={t(`workflow.config.${selectedNode?.data?.type}.${key}`)}
>
<MappingList name={key} options={variableList} />
<MappingList name={key} options={getFilteredVariableList(selectedNode?.data?.type)} />
</Form.Item>
)
@@ -581,7 +873,7 @@ const Properties: FC<PropertiesProps> = ({
<Form.Item key={key} name={key}>
<CycleVarsList
parentName={key}
options={variableList}
options={getFilteredVariableList(selectedNode?.data?.type)}
/>
</Form.Item>
)
@@ -655,9 +947,9 @@ const Properties: FC<PropertiesProps> = ({
findParentLoopIteration(selectedNode.id);
}
return [...variableList, ...loopIterationVars];
return [...getFilteredVariableList(selectedNode?.data?.type), ...loopIterationVars];
}
return variableList;
return getFilteredVariableList(selectedNode?.data?.type);
})()
}
/>
@@ -678,7 +970,7 @@ const Properties: FC<PropertiesProps> = ({
? <Input.TextArea placeholder={t('common.pleaseEnter')} />
: config.type === 'select'
? <Select
options={config.needTranslation ? config.options?.map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
options={config.needTranslation ? (config.options || []).map(vo => ({ ...vo, label: t(vo.label) })) : config.options}
placeholder={t('common.pleaseSelect')}
/>
: config.type === 'inputNumber'
@@ -698,9 +990,10 @@ const Properties: FC<PropertiesProps> = ({
? <VariableSelect
placeholder={t('common.pleaseSelect')}
options={(() => {
const baseVariableList = getFilteredVariableList(selectedNode?.data?.type);
// Apply filtering if specified in config
if (config.filterNodeTypes || config.filterVariableNames) {
return variableList.filter(variable => {
return baseVariableList.filter(variable => {
const nodeTypeMatch = !config.filterNodeTypes ||
(Array.isArray(config.filterNodeTypes) && config.filterNodeTypes.includes(variable.nodeData?.type));
const variableNameMatch = !config.filterVariableNames ||
@@ -721,22 +1014,38 @@ const Properties: FC<PropertiesProps> = ({
return nodeData?.cycle === selectedNode.id;
});
return variableList.filter(variable =>
return baseVariableList.filter(variable =>
childNodes.some(node => node.id === variable.nodeData?.id)
);
}
return variableList;
return baseVariableList;
})()
}
/>
: config.type === 'switch'
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
? <Switch onChange={key === 'group' ? () => { form.setFieldValue('group_variables', []) } : undefined} />
: config.type === 'categoryList'
? <CategoryList parentName={key} selectedNode={selectedNode} graphRef={graphRef} />
: config.type === 'conditionList'
? <ConditionList
parentName={key}
options={variableList}
options={(() => {
// For loop nodes, add cycle_vars to condition options
if (selectedNode?.data?.type === 'loop') {
const cycleVars = values?.cycle_vars || [];
const cycleVarSuggestions: Suggestion[] = cycleVars.map((cycleVar: any) => ({
key: `${selectedNode.id}_cycle_${cycleVar.name}`,
label: cycleVar.name,
type: 'variable',
dataType: cycleVar.type || 'String',
value: `${selectedNode.getData().id}.${cycleVar.name}`,
nodeData: selectedNode.getData(),
}));
return [...getFilteredVariableList(selectedNode?.data?.type), ...cycleVarSuggestions];
}
return getFilteredVariableList(selectedNode?.data?.type);
})()
}
selectedNode={selectedNode}
graphRef={graphRef}
addBtnText={t('workflow.config.addCase')}

View File

@@ -300,6 +300,7 @@ export const nodeLibrary: NodeLibrary[] = [
config: {
cycle_vars: {
type: 'cycleVarsList',
defaultValue: []
},
condition: {
type: 'conditionList',
@@ -395,12 +396,14 @@ export const nodeLibrary: NodeLibrary[] = [
},
retry: {
type: 'switch',
defaultValue: false
defaultValue: {
enable: false
}
},
error_handle: {
type: 'define',
defaultValue: {
method: 'default'
method: 'none'
}
}
}
@@ -420,11 +423,13 @@ export const nodeLibrary: NodeLibrary[] = [
config: {
mapping: {
type: 'mappingList',
defaultValue: []
defaultValue: [{name: 'arg1'}]
},
template: {
type: 'messageEditor',
isArray: false,
enableJinja2: true,
defaultValue: "{{arg1}}"
},
}
}

View File

@@ -193,6 +193,27 @@ export const useWorkflowGraph = ({
nodeConfig.height = newHeight;
}
// 如果是http-request节点检查error_handle.method配置
if (type === 'http-request' && (config as any).error_handle?.method === 'branch') {
const portAttrs = {
circle: {
r: 4, magnet: true, stroke: '#155EEF', strokeWidth: 2, fill: '#155EEF', position: { top: 22 }
},
};
nodeConfig.ports = {
groups: {
right: { position: 'right', attrs: portAttrs },
left: { position: 'left', attrs: portAttrs },
},
items: [
{ group: 'left' },
{ group: 'right', id: 'right' },
{ group: 'right', id: 'ERROR', attrs: { text: { text: t('workflow.config.http-request.errorBranch'), fontSize: 12, fill: '#5B6167' }}}
]
};
}
return nodeConfig
})
@@ -284,6 +305,14 @@ export const useWorkflowGraph = ({
}
}
// 如果是http-request节点且有label根据label匹配对应的端口
if (sourceCell.getData()?.type === 'http-request' && label) {
const matchingPort = sourcePorts.find((port: any) => port.id === label);
if (matchingPort) {
sourcePort = label;
}
}
const edgeConfig = {
source: {
cell: sourceCell.id,
@@ -954,6 +983,23 @@ export const useWorkflowGraph = ({
};
}
// 如果是http-request节点的右侧端口连线添加label
if (sourceCell?.getData()?.type === 'http-request') {
if (sourcePortId === 'ERROR') {
return {
source: sourceCell.getData().id,
target: targetCell?.getData().id,
label: 'ERROR',
};
} else {
return {
source: sourceCell.getData().id,
target: targetCell?.getData().id,
label: 'SUCCESS',
};
}
}
return {
source: sourceCell?.getData().id,
target: targetCell?.getData().id,

View File

@@ -26,6 +26,7 @@ export interface NodeConfig {
group_variables?: Array<{ key: string, value: string[] }>
cycle?: string;
cycle_vars?: Array<{ name: string; type: string; value: string; input_type: string; }>
[key: string]: unknown;
}