Merge branch 'develop' of codeup.aliyun.com:redbearai/python/redbear-mem-open into develop

# Conflicts:
#	api/app/services/workspace_service.py
This commit is contained in:
Mark
2025-12-24 20:59:51 +08:00
107 changed files with 5443 additions and 5005 deletions

View File

@@ -3,26 +3,26 @@
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
"""
import time
import uuid
import json
import asyncio
import datetime
from typing import Dict, Any, Optional, List, AsyncGenerator
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from sqlalchemy import select
import json
import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.models import AgentConfig, ModelConfig, ModelApiKey
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from app.core.rag.nlp.search import knowledge_retrieval
from app.services.langchain_tool_server import Search
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"""
logger.info(f" 长期记忆工具被调用question={question}, user={end_user_id}")
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
from app.db import get_db
db = next(get_db())
try:
memory_content = asyncio.run(
MemoryAgentService().read_memory(
group_id=end_user_id,
message=question,
history=[],
search_switch="1",
config_id=config_id,
db=db,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
)
)
)
finally:
db.close()
logger.info(f'用户IDAgent:{end_user_id}')
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
@@ -713,9 +719,9 @@ class DraftRunService:
Raises:
BusinessException: 当指定的会话不存在时
"""
from app.services.conversation_service import ConversationService
from app.schemas.conversation_schema import ConversationCreate
from app.models import Conversation as ConversationModel
from app.schemas.conversation_schema import ConversationCreate
from app.services.conversation_service import ConversationService
conversation_service = ConversationService(self.db)

View File

@@ -7,14 +7,15 @@ Classes:
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
"""
from typing import Dict, Any, Optional, List
import statistics
import json
from pydantic import BaseModel, Field
import statistics
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_business_logger
from app.repositories.neo4j.emotion_repository import EmotionRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.logging_config import get_business_logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
logger = get_business_logger()
@@ -64,19 +65,9 @@ class EmotionAnalyticsService:
"""获取情绪标签统计
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
Args:
end_user_id: 宿主ID用户组ID
emotion_type: 可选的情绪类型过滤
start_date: 可选的开始日期ISO格式
end_date: 可选的结束日期ISO格式
limit: 返回结果的最大数量
Returns:
Dict: 包含情绪标签统计的响应数据:
- tags: 情绪标签列表
- total_count: 总情绪数量
- time_range: 时间范围信息
确保返回所有6个情绪维度joy、sadness、anger、fear、surprise、neutral
即使某些维度没有数据也会返回count=0的记录。
"""
try:
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
@@ -91,8 +82,34 @@ class EmotionAnalyticsService:
limit=limit
)
# 定义所有6个情绪维度
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
# 将查询结果转换为字典,方便查找
tags_dict = {tag["emotion_type"]: tag for tag in tags}
# 补全缺失的情绪维度
complete_tags = []
for emotion in all_emotion_types:
if emotion in tags_dict:
complete_tags.append(tags_dict[emotion])
else:
# 如果该情绪类型不存在,添加默认值
complete_tags.append({
"emotion_type": emotion,
"count": 0,
"percentage": 0.0,
"avg_intensity": 0.0
})
# 计算总数
total_count = sum(tag["count"] for tag in tags)
total_count = sum(tag["count"] for tag in complete_tags)
# 如果有数据重新计算百分比因为补全了0值项
if total_count > 0:
for tag in complete_tags:
if tag["count"] > 0:
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
# 构建时间范围信息
time_range = {}
@@ -103,12 +120,12 @@ class EmotionAnalyticsService:
# 格式化响应
response = {
"tags": tags,
"tags": complete_tags,
"total_count": total_count,
"time_range": time_range if time_range else None
}
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(tags)}")
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
return response
except Exception as e:
@@ -454,7 +471,7 @@ class EmotionAnalyticsService:
async def generate_emotion_suggestions(
self,
end_user_id: str,
config_id: Optional[int] = None
db: Session,
) -> Dict[str, Any]:
"""生成个性化情绪建议
@@ -462,7 +479,7 @@ class EmotionAnalyticsService:
Args:
end_user_id: 宿主ID用户组ID
config_id: 配置ID可选用于从数据库加载LLM配置
db: 数据库会话
Returns:
Dict: 包含个性化建议的响应:
@@ -470,14 +487,32 @@ class EmotionAnalyticsService:
- suggestions: 建议列表3-5条
"""
try:
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
logger.info(f"生成个性化情绪建议: user={end_user_id}")
# 1. 如果提供了 config_id从数据库加载配置
if config_id is not None:
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
# 1. 从 end_user_id 获取关联的 memory_config_id
llm_client = None
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is not None:
from app.services.memory_config_service import (
MemoryConfigService,
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=int(config_id),
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
except Exception as e:
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
# 2. 获取情绪健康数据
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
@@ -498,8 +533,9 @@ class EmotionAnalyticsService:
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
# 7. 调用LLM生成建议使用配置中的LLM
from app.core.memory.utils.llm.llm_utils import get_llm_client
llm_client = get_llm_client()
if llm_client is None:
# 无法获取配置时,抛出错误而不是使用默认配置
raise ValueError("无法获取LLM配置请确保end_user关联了有效的memory_config")
# 将 prompt 转换为 messages 格式
messages = [
@@ -598,7 +634,9 @@ class EmotionAnalyticsService:
Returns:
str: LLM prompt
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_suggestions_prompt,
)
prompt = await render_emotion_suggestions_prompt(
health_data=health_data,

View File

@@ -9,10 +9,12 @@ Classes:
import logging
from typing import Optional
from app.core.memory.models.emotion_models import EmotionExtraction
from app.models.data_config_model import DataConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.llm_tools.llm_client import LLMClientException
from app.core.memory.models.emotion_models import EmotionExtraction
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.data_config_model import DataConfig
logger = logging.getLogger(__name__)
@@ -50,7 +52,9 @@ class EmotionExtractionService:
"""
if self.llm_client is None or model_id:
effective_model_id = model_id or self.llm_id
self.llm_client = get_llm_client(effective_model_id)
with get_db_context() as db:
factory = MemoryClientFactory(db)
self.llm_client = factory.get_llm_client(effective_model_id)
return self.llm_client
async def extract_emotion(
@@ -142,7 +146,9 @@ class EmotionExtractionService:
Returns:
Formatted prompt string for LLM
"""
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
from app.core.memory.utils.prompt.prompt_utils import (
render_emotion_extraction_prompt,
)
prompt = await render_emotion_extraction_prompt(
statement=statement,

View File

@@ -4,50 +4,48 @@ Memory Agent Service
Handles business logic for memory agent operations including read/write services,
health checks, and message type classification.
"""
import json
import os
import re
import time
import json
import uuid
from threading import Lock
from typing import Dict, List, Optional, Any, AsyncGenerator
from app.services.memory_konwledges_server import write_rag
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from sqlalchemy.orm import Session
from sqlalchemy import func
from pydantic import BaseModel, Field
from app.core.config import settings
from app.core.logging_config import get_logger
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.db import get_db
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
from app.core.memory.utils.config.definitions import reload_configuration_from_database
from app.schemas.file_schema import CustomTextFileCreate
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_mcp_adapters.client import MultiServerMCPClient
from langchain_mcp_adapters.tools import load_mcp_tools
from pydantic import BaseModel, Field
from sqlalchemy import func
from sqlalchemy.orm import Session
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__)
config_logger = get_config_logger()
# Initialize Neo4j connector for analytics functions
_neo4j_connector = Neo4jConnector()
db_gen = get_db()
db = next(db_gen)
class MemoryAgentService:
"""Service for memory agent operations"""
@@ -257,14 +255,17 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: Message to write
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Write operation result status
@@ -272,24 +273,40 @@ class MemoryAgentService:
Raises:
ValueError: If config loading fails or write operation fails
"""
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
import time
start_time = time.time()
# 如果 config_id 为 None使用默认值 "17"
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
mcp_config = get_mcp_server_config()
client = MultiServerMCPClient(mcp_config)
@@ -300,20 +317,43 @@ class MemoryAgentService:
async with client.session("data_flow") as session:
logger.debug("Connected to MCP Server: data_flow")
tools = await load_mcp_tools(session)
workflow_errors = [] # Track errors from workflow
# Pass config_id to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
# Pass memory_config to the graph workflow
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
logger.debug("Write graph created successfully")
config = {"configurable": {"thread_id": group_id}}
async for event in graph.astream(
{"messages": message, "config_id": config_id},
{"messages": message, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.error(f"Write workflow failed with errors: {error_details}")
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
operation="WRITE",
config_id=config_id,
group_id=group_id,
success=False,
duration=duration,
error=error_details
)
raise ValueError(f"Write workflow failed: {error_details}")
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
async def read_memory(
self,
@@ -321,7 +361,8 @@ class MemoryAgentService:
message: str,
history: List[Dict],
search_switch: str,
config_id: str,
config_id: Optional[str],
db: Session,
storage_type: str,
user_rag_memory_id: str
) -> Dict:
@@ -334,11 +375,14 @@ class MemoryAgentService:
- "2": Direct answer based on context
Args:
group_id: Group identifier
group_id: Group identifier (also used as end_user_id)
message: User message
history: Conversation history
search_switch: Search mode switch
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
user_rag_memory_id: User RAG memory ID
Returns:
Dict with 'answer' and 'intermediate_outputs' keys
@@ -350,8 +394,18 @@ class MemoryAgentService:
import time
start_time = time.time()
if config_id==None:
config_id = os.getenv("config_id")
# Resolve config_id if None using end_user's connected config
if config_id is None:
try:
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
if config_id is None:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
@@ -365,15 +419,19 @@ class MemoryAgentService:
group_lock = self.get_group_lock(group_id)
with group_lock:
# Step 1: Load configuration from database
from app.core.memory.utils.config.definitions import reload_configuration_from_database
config_loaded = reload_configuration_from_database(config_id)
if not config_loaded:
error_msg = f"Failed to load configuration for config_id: {config_id}"
# Step 1: Load configuration from database only
try:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg)
# 记录失败的操作
# Log failed operation
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(
@@ -387,8 +445,6 @@ class MemoryAgentService:
raise ValueError(error_msg)
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
# Step 2: Prepare history
history.append({"role": "user", "content": message})
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
@@ -404,45 +460,52 @@ class MemoryAgentService:
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass config_id to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
# Pass memory_config to the graph workflow
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
async for event in graph.astream(
{"messages": history, "config_id": config_id},
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
workflow_errors.extend(event.get('errors', []))
for msg in messages:
msg_content = msg.content
msg_role = msg.__class__.__name__.lower().replace("message", "")
outputs.append({
"role": msg.__class__.__name__.lower().replace("message", ""),
"role": msg_role,
"content": msg_content
})
# Extract intermediate outputs
if hasattr(msg, 'content'):
try:
# Debug: log message type and content preview
msg_type = msg.__class__.__name__
content_preview = str(msg_content)[:200] if msg_content else "empty"
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
content_to_parse = msg_content
if isinstance(msg_content, list):
for block in msg_content:
if isinstance(block, dict) and block.get('type') == 'text':
content_to_parse = block.get('text', '')
break
else:
continue # No text block found
# Try to parse content as JSON
if isinstance(msg_content, str):
if isinstance(content_to_parse, str):
try:
parsed = json.loads(msg_content)
parsed = json.loads(content_to_parse)
if isinstance(parsed, dict):
# Debug: log what keys are in parsed
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
# Check for single intermediate output
if '_intermediate' in parsed:
intermediate_data = parsed['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
@@ -450,34 +513,14 @@ class MemoryAgentService:
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in parsed:
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
for intermediate_data in parsed['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except (json.JSONDecodeError, ValueError):
pass
elif isinstance(msg_content, dict):
# Check for single intermediate output
if '_intermediate' in msg_content:
intermediate_data = msg_content['_intermediate']
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
# Check for multiple intermediate outputs (from Retrieve)
if '_intermediates' in msg_content:
for intermediate_data in msg_content['_intermediates']:
output_key = self._create_intermediate_key(intermediate_data)
if output_key not in seen_intermediates:
seen_intermediates.add(output_key)
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
@@ -489,18 +532,57 @@ class MemoryAgentService:
for messages in outputs:
if messages['role'] == 'tool':
message = messages['content']
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(message, list):
# Extract text from MCP content blocks
for block in message:
if isinstance(block, dict) and block.get('type') == 'text':
message = block.get('text', '')
break
else:
continue # No text block found
try:
message = json.loads(message) if isinstance(message, str) else message
if isinstance(message, dict) and message.get('status') != '':
summary_result = message.get('summary_result')
if summary_result:
final_answer = summary_result
parsed = json.loads(message) if isinstance(message, str) else message
if isinstance(parsed, dict):
if parsed.get('status') == 'success':
summary_result = parsed.get('summary_result')
if summary_result:
final_answer = summary_result
except (json.JSONDecodeError, ValueError):
pass
# 记录成功的操作
total_duration = time.time() - start_time
if audit_logger:
# Check for workflow errors
if workflow_errors:
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
logger.warning(f"Read workflow completed with errors: {error_details}")
if audit_logger:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
group_id=group_id,
success=False,
duration=total_duration,
error=error_details,
details={
"search_switch": search_switch,
"history_length": len(history),
"intermediate_outputs_count": len(intermediate_outputs),
"has_answer": bool(final_answer),
"errors": workflow_errors
}
)
# Raise error if no answer was produced
if not final_answer:
raise ValueError(f"Read workflow failed: {error_details}")
if audit_logger and not workflow_errors:
audit_logger.log_operation(
operation="READ",
config_id=config_id,
@@ -612,19 +694,29 @@ class MemoryAgentService:
else:
return output
async def classify_message_type(self, message: str) -> Dict:
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
Args:
message: User message to classify
config_id: Configuration ID to load LLM model from database
db: Database session
Returns:
Type classification result
"""
logger.info("Classifying message type")
status = await status_typle(message)
# Load configuration to get LLM model ID
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="MemoryAgentService"
)
status = await status_typle(message, memory_config.llm_model_id)
logger.debug(f"Message type: {status}")
return status
@@ -790,7 +882,9 @@ class MemoryAgentService:
async def get_user_profile(
self,
end_user_id: Optional[str] = None,
current_user_id: Optional[str] = None
current_user_id: Optional[str] = None,
llm_id: Optional[str] = None,
db: Session = None
) -> Dict[str, Any]:
"""
获取用户详情,包含:
@@ -801,6 +895,8 @@ class MemoryAgentService:
参数:
- end_user_id: 用户ID可选
- current_user_id: 当前登录用户的ID保留参数
- llm_id: LLM模型ID用于生成标签可选如果不提供则跳过标签生成
- db: 数据库会话(可选)
返回格式:
{
@@ -817,7 +913,7 @@ class MemoryAgentService:
# 1. 根据 end_user_id 获取 end_user_name
try:
if end_user_id:
if end_user_id and db:
from app.repositories import end_user_repository
from app.schemas.end_user_schema import EndUser as EndUserSchema
@@ -862,15 +958,19 @@ class MemoryAgentService:
await connector.close()
if not statements:
if not statements or not llm_id:
result["tags"] = []
if not llm_id and statements:
logger.warning("llm_id not provided, skipping tag generation")
else:
# 构建摘要文本
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
# 使用LLM提取标签
llm_client = get_llm_client()
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(llm_id)
# 定义标签提取的结构
class UserTags(BaseModel):
@@ -1032,4 +1132,69 @@ class MemoryAgentService:
# "msg": "解析失败",
# "error_code": "DOC_PARSE_ERROR",
# "data": {"error": str(e)}
# }
# }
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
    """Look up the memory configuration connected to an end user.

    Resolution steps:
        1. Load the end user and read its ``app_id``.
        2. Fetch the active release of that app with the highest version.
        3. Read ``config["memory"]["memory_content"]`` from that release and
           report it as ``memory_config_id``.

    Args:
        end_user_id: End-user ID.
        db: Database session.

    Returns:
        Dict with the keys ``end_user_id``, ``app_id``, ``release_id``,
        ``release_version`` and ``memory_config_id``. ``memory_config_id`` is
        ``None`` when the release config carries no usable ``memory`` entry.

    Raises:
        ValueError: If the end user does not exist or the app has no active
            release.
    """
    from app.models.app_release_model import AppRelease
    from app.models.end_user_model import EndUser
    from sqlalchemy import select

    logger.info(f"Getting connected config for end_user: {end_user_id}")

    # 1. Load the end user to obtain its app_id.
    end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
    if not end_user:
        logger.warning(f"End user not found: {end_user_id}")
        raise ValueError(f"终端用户不存在: {end_user_id}")

    app_id = end_user.app_id
    logger.debug(f"Found end_user app_id: {app_id}")

    # 2. Fetch the newest active release. Only one row is needed, so the
    #    LIMIT is stated in the query instead of relying on .first() to
    #    discard the remaining rows after they were selected.
    stmt = (
        select(AppRelease)
        .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
        .order_by(AppRelease.version.desc())
        .limit(1)
    )
    latest_release = db.scalars(stmt).first()
    if not latest_release:
        logger.warning(f"No active release found for app: {app_id}")
        raise ValueError(f"应用未发布: {app_id}")

    logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")

    # 3. Extract the memory config ID. Both levels are type-checked so a
    #    missing or malformed config yields None rather than an AttributeError.
    config = latest_release.config
    memory_obj = config.get('memory') if isinstance(config, dict) else None
    memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None

    result = {
        "end_user_id": str(end_user_id),
        "app_id": str(app_id),
        "release_id": str(latest_release.id),
        "release_version": latest_release.version,
        "memory_config_id": memory_config_id
    }
    logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
    return result

View File

@@ -0,0 +1,399 @@
"""
Memory Configuration Service
Centralized configuration loading and management for memory services.
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
"""
import time
from datetime import datetime
from app.core.logging_config import get_config_logger, get_logger
from app.core.validators.memory_config_validators import (
validate_and_resolve_model_id,
validate_embedding_model,
validate_model_exists_and_active,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.schemas.memory_config_schema import (
ConfigurationError,
InvalidConfigError,
MemoryConfig,
ModelInactiveError,
ModelNotFoundError,
)
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
def _validate_config_id(config_id):
"""Validate configuration ID format."""
if config_id is None:
raise InvalidConfigError(
"Configuration ID cannot be None",
field_name="config_id",
invalid_value=config_id,
)
if isinstance(config_id, int):
if config_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {config_id}",
field_name="config_id",
invalid_value=config_id,
)
return config_id
if isinstance(config_id, str):
try:
parsed_id = int(config_id.strip())
if parsed_id <= 0:
raise InvalidConfigError(
f"Configuration ID must be positive: {parsed_id}",
field_name="config_id",
invalid_value=config_id,
)
return parsed_id
except ValueError:
raise InvalidConfigError(
f"Invalid configuration ID format: '{config_id}'",
field_name="config_id",
invalid_value=config_id,
)
raise InvalidConfigError(
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
field_name="config_id",
invalid_value=config_id,
)
class MemoryConfigService:
"""
Centralized service for memory configuration loading and validation.
This class provides a single implementation of configuration loading logic
that can be shared across multiple services, eliminating code duplication.
Usage:
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(config_id)
model_config = config_service.get_model_config(model_id)
"""
def __init__(self, db: Session):
    """Bind the service to the database session it will work with.

    Args:
        db: SQLAlchemy database session, stored on the instance as
            ``self.db``.
    """
    self.db = db
def load_memory_config(
self,
config_id: int,
service_name: str = "MemoryConfigService",
) -> MemoryConfig:
"""
Load memory configuration from database by config_id.
Args:
config_id: Configuration ID from database
service_name: Name of the calling service (for logging purposes)
Returns:
MemoryConfig: Immutable configuration object
Raises:
ConfigurationError: If validation fails
"""
start_time = time.time()
config_logger.info(
"Starting memory configuration loading",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": config_id,
},
)
logger.info(f"Loading memory configuration from database: config_id={config_id}")
try:
validated_config_id = _validate_config_id(config_id)
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
if not result:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Configuration not found in database",
extra={
"operation": "load_memory_config",
"config_id": validated_config_id,
"load_result": "not_found",
"elapsed_ms": elapsed_ms,
"service": service_name,
},
)
raise ConfigurationError(
f"Configuration {validated_config_id} not found in database"
)
memory_config, workspace = result
# Validate embedding model
embedding_uuid = validate_embedding_model(
validated_config_id,
memory_config.embedding_id,
self.db,
workspace.tenant_id,
workspace.id,
)
# Resolve LLM model
llm_uuid, llm_name = validate_and_resolve_model_id(
memory_config.llm_id,
"llm",
self.db,
workspace.tenant_id,
required=True,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Resolve optional rerank model
rerank_uuid = None
rerank_name = None
if memory_config.rerank_id:
rerank_uuid, rerank_name = validate_and_resolve_model_id(
memory_config.rerank_id,
"rerank",
self.db,
workspace.tenant_id,
required=False,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Get embedding model name
embedding_name, _ = validate_model_exists_and_active(
embedding_uuid,
"embedding",
self.db,
workspace.tenant_id,
config_id=validated_config_id,
workspace_id=workspace.id,
)
# Create immutable MemoryConfig object
config = MemoryConfig(
config_id=memory_config.config_id,
config_name=memory_config.config_name,
workspace_id=workspace.id,
workspace_name=workspace.name,
tenant_id=workspace.tenant_id,
llm_model_id=llm_uuid,
llm_model_name=llm_name,
embedding_model_id=embedding_uuid,
embedding_model_name=embedding_name,
rerank_model_id=rerank_uuid,
rerank_model_name=rerank_name,
storage_type=workspace.storage_type or "neo4j",
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
reflexion_enabled=memory_config.enable_self_reflexion or False,
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
reflexion_range=memory_config.reflexion_range or "retrieval",
reflexion_baseline=memory_config.baseline or "time",
loaded_at=datetime.now(),
# Pipeline config: Deduplication
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
# Pipeline config: Statement extraction
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
# Pipeline config: Forgetting engine
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
# Pipeline config: Pruning
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
)
elapsed_ms = (time.time() - start_time) * 1000
config_logger.info(
"Memory configuration loaded successfully",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": validated_config_id,
"config_name": config.config_name,
"workspace_id": str(config.workspace_id),
"load_result": "success",
"elapsed_ms": elapsed_ms,
},
)
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
return config
except Exception as e:
elapsed_ms = (time.time() - start_time) * 1000
config_logger.error(
"Failed to load memory configuration",
extra={
"operation": "load_memory_config",
"service": service_name,
"config_id": config_id,
"load_result": "error",
"error_type": type(e).__name__,
"error_message": str(e),
"elapsed_ms": elapsed_ms,
},
exc_info=True,
)
logger.error(f"Failed to load memory configuration {config_id}: {e}")
if isinstance(e, (ConfigurationError, ValueError)):
raise
else:
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
def get_model_config(self, model_id: str) -> dict:
"""Get LLM model configuration by ID.
Args:
model_id: Model ID to look up
Returns:
Dict with model configuration including api_key, base_url, etc.
"""
from app.core.config import settings
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
if not config:
logger.warning(f"Model ID {model_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": settings.LLM_TIMEOUT,
"max_retries": settings.LLM_MAX_RETRIES,
}
def get_embedder_config(self, embedding_id: str) -> dict:
"""Get embedding model configuration by ID.
Args:
embedding_id: Embedding model ID to look up
Returns:
Dict with embedder configuration including api_key, base_url, etc.
"""
from app.models.models_model import ModelApiKey
from app.services.model_service import ModelConfigService as ModelSvc
from fastapi import status
from fastapi.exceptions import HTTPException
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
if not config:
logger.warning(f"Embedding model ID {embedding_id} not found")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
api_config: ModelApiKey = config.api_keys[0]
return {
"model_name": api_config.model_name,
"provider": api_config.provider,
"api_key": api_config.api_key,
"base_url": api_config.api_base,
"model_config_id": api_config.model_config_id,
"type": config.type,
"timeout": 120.0,
"max_retries": 5,
}
@staticmethod
def get_pipeline_config(memory_config: MemoryConfig):
"""Build ExtractionPipelineConfig from MemoryConfig.
Args:
memory_config: MemoryConfig object containing all pipeline settings.
Returns:
ExtractionPipelineConfig with deduplication, statement extraction,
and forgetting engine settings.
"""
from app.core.memory.models.variate_config import (
DedupConfig,
ExtractionPipelineConfig,
ForgettingEngineConfig,
StatementExtractionConfig,
)
dedup_config = DedupConfig(
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
fuzzy_name_threshold_strict=memory_config.t_name_strict,
fuzzy_type_threshold_strict=memory_config.t_type_strict,
fuzzy_overall_threshold=memory_config.t_overall,
)
stmt_config = StatementExtractionConfig(
statement_granularity=memory_config.statement_granularity,
include_dialogue_context=memory_config.include_dialogue_context,
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
)
forget_config = ForgettingEngineConfig(
offset=memory_config.offset,
lambda_time=memory_config.lambda_time,
lambda_mem=memory_config.lambda_mem,
)
return ExtractionPipelineConfig(
statement_extraction=stmt_config,
deduplication=dedup_config,
forgetting_engine=forget_config,
)
@staticmethod
def get_pruning_config(memory_config: MemoryConfig) -> dict:
"""Retrieve semantic pruning config from MemoryConfig.
Args:
memory_config: MemoryConfig object containing pruning settings.
Returns:
Dict suitable for PruningConfig.model_validate with keys:
- pruning_switch: bool
- pruning_scene: str
- pruning_threshold: float
"""
return {
"pruning_switch": memory_config.pruning_enabled,
"pruning_scene": memory_config.pruning_scene,
"pruning_threshold": memory_config.pruning_threshold,
}

View File

@@ -272,7 +272,7 @@ async def get_workspace_total_memory_count(
from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(db)
end_user = repo.get_by_id(uuid.UUID(end_user_id))
user_name = end_user.name if end_user else None
user_name = end_user.other_name if end_user else None
return {
"total_memory_count": search_result.get("total", 0),
@@ -298,10 +298,10 @@ async def get_workspace_total_memory_count(
details.append({
"end_user_id": end_user_id_str,
"count": host_total,
"name": host.name # 添加 name 字段
"name": host.other_name # 使用 other_name 字段
})
business_logger.debug(f"EndUser {end_user_id_str} ({host.name}) 记忆数: {host_total}")
business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}")
except Exception as e:
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
@@ -309,7 +309,7 @@ async def get_workspace_total_memory_count(
details.append({
"end_user_id": str(host.id),
"count": 0,
"name": host.name # 添加 name 字段
"name": host.other_name # 使用 other_name 字段
})
result = {

View File

@@ -4,38 +4,37 @@ Memory Storage Service
Handles business logic for memory storage operations.
"""
from typing import Dict, List, Optional, Any, AsyncGenerator
import os
import json
import asyncio
import json
import os
import time
import uuid
from datetime import datetime
from sqlalchemy.orm import Session
from dotenv import load_dotenv
from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.models.user_model import User
from app.core.logging_config import get_logger
from app.utils.sse_utils import format_sse_message
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import ConfigurationError
from app.schemas.memory_storage_schema import (
ConfigPilotRun,
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
ConfigUpdateForget,
ConfigKey,
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.end_user_repository import EndUserRepository
import uuid
from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message
from dotenv import load_dotenv
from sqlalchemy.orm import Session
logger = get_logger(__name__)
config_logger = get_config_logger()
# Load environment variables for Neo4j connector
load_dotenv()
@@ -247,7 +246,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
RuntimeError: 当管线执行失败时
"""
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
try:
# 发出初始进度事件
@@ -256,24 +254,12 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"time": int(time.time() * 1000)
})
# 步骤 1: 配置加载和验证(复用现有逻辑
# 步骤 1: 配置加载和验证(数据库优先
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
cid: Optional[str] = payload_cid if payload_cid else None
if not cid and os.path.isfile(dbrun_path):
try:
with open(dbrun_path, "r", encoding="utf-8") as f:
dbrun = json.load(f)
if isinstance(dbrun, dict):
sel = dbrun.get("selections", {})
if isinstance(sel, dict):
fallback_cid = str(sel.get("config_id") or "").strip()
cid = fallback_cid or None
except Exception:
cid = None
if not cid:
raise ValueError("未提供 payload.config_id且 dbrun.json 未设置 selections.config_id禁止启动试运行")
raise ValueError("未提供 payload.config_id禁止启动试运行")
# 验证 dialogue_text 必须提供
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
@@ -281,12 +267,16 @@ class DataConfigService: # 数据配置服务类PostgreSQL
if not dialogue_text:
raise ValueError("试运行模式必须提供 dialogue_text 参数")
# 应用内存覆写并刷新常量
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
# Load configuration from database only using centralized manager
try:
config_service = MemoryConfigService(self.db)
memory_config = config_service.load_memory_config(
config_id=int(cid),
service_name="MemoryStorageService.pilot_run_stream"
)
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
except ConfigurationError as e:
raise RuntimeError(f"Configuration loading failed: {e}")
# 步骤 2: 创建进度回调函数捕获管线进度
# 使用队列在回调和生成器之间传递进度事件
@@ -307,13 +297,14 @@ class DataConfigService: # 数据配置服务类PostgreSQL
async def run_pipeline():
"""在后台执行管线并捕获异常"""
try:
from app.core.memory.main import main as pipeline_main
from app.services.pilot_run_service import run_pilot_extraction
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(
dialogue_text=dialogue_text,
is_pilot_run=True,
progress_callback=progress_callback
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
await run_pilot_extraction(
memory_config=memory_config,
dialogue_text=dialogue_text,
db=self.db,
progress_callback=progress_callback,
)
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")

View File

@@ -0,0 +1,219 @@
"""
Pilot Run Service - 试运行服务
用于执行记忆系统的试运行流程,不保存到 Neo4j。
"""
import os
import re
import time
from datetime import datetime
from typing import Awaitable, Callable, Optional
from app.core.logging_config import get_memory_logger, log_time
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
get_chunked_dialogs_from_preprocessed,
)
from app.core.memory.utils.config.config_utils import (
get_pipeline_config,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
from sqlalchemy.orm import Session
# Module-level logger for the pilot-run service.
logger = get_memory_logger(__name__)
def _split_dialogue_turns(dialogue_text: str) -> list:
    """Split raw dialogue text into ``(role, content)`` pairs.

    A turn starts with ``用户`` or ``AI`` followed by an ASCII or full-width
    colon. Following lines that do not start a new turn belong to the current
    turn. If no turn is recognised, the whole text becomes one ``用户`` turn.
    """
    # The continuation group is greedy on purpose: a lazy quantifier at the
    # very end of a pattern always matches zero repetitions, which would cut
    # every multi-line turn down to its first line.
    pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*)"
    turns = [
        (role, content.strip())
        for role, content in re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
        if content.strip()
    ]
    if not turns:
        turns = [("用户", dialogue_text.strip())]
    return turns


def _append_to_log(log_file: str, text: str) -> None:
    """Append ``text`` to ``log_file`` using UTF-8."""
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(text)


async def run_pilot_extraction(
    memory_config: MemoryConfig,
    dialogue_text: str,
    db: Session,
    progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
) -> None:
    """Run the knowledge-extraction pipeline in pilot-run mode.

    The dialogue text is parsed into messages, chunked, passed through the
    extraction orchestrator with ``is_pilot_run=True`` and then through memory
    summary generation. This function itself does not save any result to
    Neo4j. Step timings are appended to ``logs/time.log``.

    Args:
        memory_config: Memory configuration object loaded from the database.
        dialogue_text: Input dialogue text.
        db: Database session handed to the memory client factory.
        progress_callback: Optional async progress callback.
            - arg 1 (stage): identifier of the current processing stage
            - arg 2 (message): human-readable progress message
            - arg 3 (data): optional dict with extra data; some calls below
              omit it, so the callback needs a default for this argument.

    Raises:
        Exception: Any failure outside the summary step is logged and
            re-raised. Failures of the summary step are logged only.
    """
    log_file = "logs/time.log"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _append_to_log(log_file, f"\n=== Pilot Run Started: {timestamp} ===\n")

    pipeline_start = time.time()
    neo4j_connector = None

    try:
        # Step 1: initialise clients
        logger.info("Initializing clients...")
        step_start = time.time()
        client_factory = MemoryClientFactory(db)
        llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
        embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
        neo4j_connector = Neo4jConnector()
        log_time("Client Initialization", time.time() - step_start, log_file)

        # Step 2: parse the dialogue text
        logger.info("Parsing dialogue text...")
        step_start = time.time()
        messages = [
            ConversationMessage(role=role, msg=content)
            for role, content in _split_dialogue_turns(dialogue_text)
        ]
        context = ConversationContext(msgs=messages)
        dialog = DialogData(
            context=context,
            ref_id="pilot_dialog_1",
            group_id=str(memory_config.workspace_id),
            user_id=str(memory_config.tenant_id),
            apply_id=str(memory_config.config_id),
            metadata={"source": "pilot_run", "input_type": "frontend_text"},
        )

        if progress_callback:
            await progress_callback("text_preprocessing", "开始预处理文本...")

        chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
            data=[dialog],
            chunker_strategy=memory_config.chunker_strategy,
            llm_client=llm_client,
        )
        logger.info(f"Processed dialogue text: {len(messages)} messages")

        # Progress callback: report each chunk (content preview capped at 200 chars)
        if progress_callback:
            for dlg in chunked_dialogs:
                for i, chunk in enumerate(dlg.chunks):
                    chunk_result = {
                        "chunk_index": i + 1,
                        "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
                        "full_length": len(chunk.content),
                        "dialog_id": dlg.id,
                        "chunker_strategy": memory_config.chunker_strategy,
                    }
                    await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)

            preprocessing_summary = {
                "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
                "total_dialogs": len(chunked_dialogs),
                "chunker_strategy": memory_config.chunker_strategy,
            }
            await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)

        log_time("Data Loading & Chunking", time.time() - step_start, log_file)

        # Step 3: initialise the pipeline orchestrator
        logger.info("Initializing extraction orchestrator...")
        step_start = time.time()
        config = get_pipeline_config(memory_config)
        logger.info(
            f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
            f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
        )
        orchestrator = ExtractionOrchestrator(
            llm_client=llm_client,
            embedder_client=embedder_client,
            connector=neo4j_connector,
            config=config,
            progress_callback=progress_callback,
            embedding_id=str(memory_config.embedding_model_id),
        )
        log_time("Orchestrator Initialization", time.time() - step_start, log_file)

        # Step 4: run the knowledge-extraction pipeline
        logger.info("Running extraction pipeline...")
        step_start = time.time()
        if progress_callback:
            await progress_callback("knowledge_extraction", "正在知识抽取...")

        extraction_result = await orchestrator.run(
            dialog_data_list=chunked_dialogs,
            is_pilot_run=True,
        )
        # The values are not used afterwards; the unpacking is kept so a
        # result of unexpected length still fails here, as before.
        (
            dialogue_nodes,
            chunk_nodes,
            statement_nodes,
            entity_nodes,
            statement_chunk_edges,
            statement_entity_edges,
            entity_edges,
        ) = extraction_result
        log_time("Extraction Pipeline", time.time() - step_start, log_file)

        if progress_callback:
            await progress_callback("generating_results", "正在生成结果...")

        # Step 5: generate memory summaries (best effort: failures are logged only)
        try:
            logger.info("Generating memory summaries...")
            step_start = time.time()
            from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
                memory_summary_generation,
            )

            await memory_summary_generation(
                chunked_dialogs,
                llm_client=llm_client,
                embedder_client=embedder_client,
            )
            log_time("Memory Summary Generation", time.time() - step_start, log_file)
        except Exception as e:
            logger.error(f"Memory summary step failed: {e}", exc_info=True)

        logger.info("Pilot run completed: Skipping Neo4j save")

    except Exception as e:
        logger.error(f"Pilot run failed: {e}", exc_info=True)
        raise
    finally:
        if neo4j_connector:
            try:
                await neo4j_connector.close()
            except Exception as close_error:
                # Closing stays best-effort, but the failure is no longer hidden.
                logger.warning(f"Failed to close Neo4j connector: {close_error}")

    total_time = time.time() - pipeline_start
    log_time("TOTAL PILOT RUN TIME", total_time, log_file)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    _append_to_log(log_file, f"=== Pilot Run Completed: {timestamp} ===\n\n")
    logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")

View File

@@ -4,15 +4,15 @@ User Memory Service
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
"""
from typing import Dict, List, Optional, Any
import uuid
from sqlalchemy.orm import Session
from typing import Any, Dict, List, Optional
from app.core.logging_config import get_logger
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from sqlalchemy.orm import Session
logger = get_logger(__name__)
@@ -284,8 +284,7 @@ class UserMemoryService:
# 使用 end_user_id 调用分析函数
try:
logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
result = await analytics_user_summary(end_user_id)
summary = result.get("summary", "")
summary = await generate_user_summary(end_user_id)
if not summary:
logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")
@@ -535,6 +534,112 @@ async def analytics_node_statistics(
return data
async def analytics_memory_types(
    db: Session,
    end_user_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Count the eight memory types and their share of the total.

    Node counts are read from Neo4j for four labels (Statement,
    ExtractedEntity, Chunk, MemorySummary) and combined as follows:

        1. PERCEPTUAL_MEMORY  = statement + entity
        2. WORKING_MEMORY     = chunk + entity
        3. SHORT_TERM_MEMORY  = chunk
        4. LONG_TERM_MEMORY   = entity
        5. EXPLICIT_MEMORY    = entity // 2
        6. IMPLICIT_MEMORY    = entity // 3
        7. EMOTIONAL_MEMORY   = statement
        8. EPISODIC_MEMORY    = memory_summary

    Args:
        db: Database session (not used by this function body).
        end_user_id: Optional end-user ID; when given, only nodes whose
            ``group_id`` equals it are counted.

    Returns:
        A list with one dict per memory type, in the order above:
            - "type": the memory type enum value (e.g. "PERCEPTUAL_MEMORY")
            - "count": the number computed for that type
            - "percentage": the type's share of the summed counts, rounded to
              two decimals (0.0 when the sum is zero)
    """

    async def _count_nodes(label: str) -> int:
        """Return the number of nodes carrying ``label`` (optionally per user)."""
        if end_user_id:
            cypher = f"""
            MATCH (n:{label})
            WHERE n.group_id = $group_id
            RETURN count(n) as count
            """
            rows = await _neo4j_connector.execute_query(cypher, group_id=end_user_id)
        else:
            cypher = f"""
            MATCH (n:{label})
            RETURN count(n) as count
            """
            rows = await _neo4j_connector.execute_query(cypher)
        if not rows or len(rows) == 0:
            return 0
        return rows[0]["count"]

    # Queries run one after another, in this fixed order.
    statements = await _count_nodes("Statement")
    entities = await _count_nodes("ExtractedEntity")
    chunks = await _count_nodes("Chunk")
    summaries = await _count_nodes("MemorySummary")

    # (enum value, computed count) for each of the eight memory types.
    per_type = [
        ("PERCEPTUAL_MEMORY", statements + entities),
        ("WORKING_MEMORY", chunks + entities),
        ("SHORT_TERM_MEMORY", chunks),
        ("LONG_TERM_MEMORY", entities),
        ("EXPLICIT_MEMORY", entities // 2),
        ("IMPLICIT_MEMORY", entities // 3),
        ("EMOTIONAL_MEMORY", statements),
        ("EPISODIC_MEMORY", summaries),
    ]
    grand_total = sum(amount for _, amount in per_type)

    return [
        {
            "type": type_name,
            "count": amount,
            "percentage": round((amount / grand_total * 100), 2) if grand_total > 0 else 0.0,
        }
        for type_name, amount in per_type
    ]
async def analytics_graph_data(
db: Session,
end_user_id: str,

View File

@@ -5,29 +5,37 @@ import uuid
from os import getenv
from typing import List, Optional
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, PermissionDeniedException
from app.core.logging_config import get_business_logger
from app.models.user_model import User
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.schemas.workspace_schema import (
WorkspaceCreate,
WorkspaceUpdate,
WorkspaceInviteCreate,
WorkspaceInviteResponse,
InviteValidateResponse,
InviteAcceptRequest,
WorkspaceMemberUpdate
)
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
from app.models.workspace_model import (
InviteStatus,
Workspace,
WorkspaceMember,
WorkspaceRole,
)
from app.repositories import workspace_repository
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
from app.schemas.workspace_schema import (
InviteAcceptRequest,
InviteValidateResponse,
WorkspaceCreate,
WorkspaceInviteCreate,
WorkspaceInviteResponse,
WorkspaceMemberUpdate,
WorkspaceModelsUpdate,
WorkspaceUpdate,
)
from dotenv import load_dotenv
from sqlalchemy.orm import Session
# 获取业务逻辑专用日志器
business_logger = get_business_logger()
load_dotenv()
def switch_workspace(
db: Session,
@@ -131,10 +139,9 @@ def create_workspace(
f"{db_workspace.id} 创建知识库"
)
try:
import os
from app.schemas.knowledge_schema import KnowledgeCreate
from app.models.knowledge_model import KnowledgeType, PermissionType
from app.repositories import knowledge_repository
from app.schemas.knowledge_schema import KnowledgeCreate
# 创建知识库数据
knowledge_data = KnowledgeCreate(
@@ -229,7 +236,7 @@ def get_workspace_members(
)
# 权限检查:工作空间成员或超级管理员可以查看成员列表
from app.core.permissions import permission_service, Subject, Resource, Action
from app.core.permissions import Action, Resource, Subject, permission_service
member = workspace_repository.get_member_in_workspace(
db=db, user_id=user.id, workspace_id=workspace_id
)
@@ -322,7 +329,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
)
# 使用统一权限服务检查管理权限
from app.core.permissions import permission_service, Subject, Resource, Action
from app.core.permissions import Action, Resource, Subject, permission_service
# 获取用户的工作空间成员关系
member = workspace_repository.get_member_in_workspace(
@@ -800,3 +807,53 @@ def get_workspace_models_configs(
)
return configs
def update_workspace_models_configs(
    db: Session,
    workspace_id: uuid.UUID,
    models_update: WorkspaceModelsUpdate,
    user: User,
) -> Workspace:
    """Update a workspace's model configuration (llm, embedding, rerank).

    Only fields of ``models_update`` that are not ``None`` are applied. A
    non-``None`` but falsy value stores ``None``; any other value is stored
    as ``str(value)``.

    Args:
        db: Database session.
        workspace_id: ID of the workspace to update.
        models_update: Object carrying the new ``llm`` / ``embedding`` /
            ``rerank`` values.
        user: Current user; checked by ``_check_workspace_admin_permission``
            before anything is changed.

    Returns:
        Workspace: The refreshed workspace object after commit.

    Raises:
        BusinessException: With ``BizCode.INTERNAL_ERROR`` if applying or
            committing the update fails; the transaction is rolled back and
            the original exception is attached as ``__cause__``.
    """
    business_logger.info(f"用户 {user.username} 请求更新工作空间 {workspace_id} 的模型配置")

    # Check that the user has admin permission (done outside the try so its
    # own errors are not rewrapped below).
    db_workspace = _check_workspace_admin_permission(db, workspace_id, user)

    # (attribute name, debug-log prefix) for every updatable model slot.
    model_fields = (
        ("llm", "更新LLM配置"),
        ("embedding", "更新嵌入模型配置"),
        ("rerank", "更新重排序模型配置"),
    )

    try:
        for field_name, log_prefix in model_fields:
            new_value = getattr(models_update, field_name)
            if new_value is None:
                continue
            setattr(db_workspace, field_name, str(new_value) if new_value else None)
            business_logger.debug(f"{log_prefix}: {new_value}")

        db.add(db_workspace)
        db.commit()
        db.refresh(db_workspace)

        business_logger.info(
            f"工作空间模型配置更新成功: workspace_id={workspace_id}, "
            f"llm={db_workspace.llm}, embedding={db_workspace.embedding}, rerank={db_workspace.rerank}"
        )
        return db_workspace
    except Exception as e:
        business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}")
        db.rollback()
        raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR) from e