Merge branch 'develop' of codeup.aliyun.com:redbearai/python/redbear-mem-open into develop
# Conflicts: # api/app/services/workspace_service.py
This commit is contained in:
@@ -3,26 +3,26 @@
|
||||
|
||||
提供 Agent 试运行功能,允许用户在不发布应用的情况下测试配置。
|
||||
"""
|
||||
import time
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Dict, Any, Optional, List, AsyncGenerator
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import select
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.models import AgentConfig, ModelConfig, ModelApiKey
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.models import AgentConfig, ModelApiKey, ModelConfig
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.services.langchain_tool_server import Search
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -83,17 +83,23 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"""
|
||||
logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}")
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="1",
|
||||
config_id=config_id,
|
||||
db=db,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
logger.info(f'用户ID:Agent:{end_user_id}')
|
||||
logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id})
|
||||
|
||||
@@ -713,9 +719,9 @@ class DraftRunService:
|
||||
Raises:
|
||||
BusinessException: 当指定的会话不存在时
|
||||
"""
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.models import Conversation as ConversationModel
|
||||
from app.schemas.conversation_schema import ConversationCreate
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
|
||||
|
||||
@@ -7,14 +7,15 @@ Classes:
|
||||
EmotionAnalyticsService: 情绪分析服务,提供各种情绪分析功能
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import statistics
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
import statistics
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.repositories.neo4j.emotion_repository import EmotionRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.logging_config import get_business_logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -64,19 +65,9 @@ class EmotionAnalyticsService:
|
||||
"""获取情绪标签统计
|
||||
|
||||
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
emotion_type: 可选的情绪类型过滤
|
||||
start_date: 可选的开始日期(ISO格式)
|
||||
end_date: 可选的结束日期(ISO格式)
|
||||
limit: 返回结果的最大数量
|
||||
|
||||
Returns:
|
||||
Dict: 包含情绪标签统计的响应数据:
|
||||
- tags: 情绪标签列表
|
||||
- total_count: 总情绪数量
|
||||
- time_range: 时间范围信息
|
||||
确保返回所有6个情绪维度(joy、sadness、anger、fear、surprise、neutral),
|
||||
即使某些维度没有数据也会返回count=0的记录。
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取情绪标签统计: user={end_user_id}, type={emotion_type}, "
|
||||
@@ -91,8 +82,34 @@ class EmotionAnalyticsService:
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 定义所有6个情绪维度
|
||||
all_emotion_types = ['joy', 'sadness', 'anger', 'fear', 'surprise', 'neutral']
|
||||
|
||||
# 将查询结果转换为字典,方便查找
|
||||
tags_dict = {tag["emotion_type"]: tag for tag in tags}
|
||||
|
||||
# 补全缺失的情绪维度
|
||||
complete_tags = []
|
||||
for emotion in all_emotion_types:
|
||||
if emotion in tags_dict:
|
||||
complete_tags.append(tags_dict[emotion])
|
||||
else:
|
||||
# 如果该情绪类型不存在,添加默认值
|
||||
complete_tags.append({
|
||||
"emotion_type": emotion,
|
||||
"count": 0,
|
||||
"percentage": 0.0,
|
||||
"avg_intensity": 0.0
|
||||
})
|
||||
|
||||
# 计算总数
|
||||
total_count = sum(tag["count"] for tag in tags)
|
||||
total_count = sum(tag["count"] for tag in complete_tags)
|
||||
|
||||
# 如果有数据,重新计算百分比(因为补全了0值项)
|
||||
if total_count > 0:
|
||||
for tag in complete_tags:
|
||||
if tag["count"] > 0:
|
||||
tag["percentage"] = round((tag["count"] / total_count) * 100, 2)
|
||||
|
||||
# 构建时间范围信息
|
||||
time_range = {}
|
||||
@@ -103,12 +120,12 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 格式化响应
|
||||
response = {
|
||||
"tags": tags,
|
||||
"tags": complete_tags,
|
||||
"total_count": total_count,
|
||||
"time_range": time_range if time_range else None
|
||||
}
|
||||
|
||||
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(tags)}")
|
||||
logger.info(f"情绪标签统计完成: total_count={total_count}, tags_count={len(complete_tags)}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -454,7 +471,7 @@ class EmotionAnalyticsService:
|
||||
async def generate_emotion_suggestions(
|
||||
self,
|
||||
end_user_id: str,
|
||||
config_id: Optional[int] = None
|
||||
db: Session,
|
||||
) -> Dict[str, Any]:
|
||||
"""生成个性化情绪建议
|
||||
|
||||
@@ -462,7 +479,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
Args:
|
||||
end_user_id: 宿主ID(用户组ID)
|
||||
config_id: 配置ID(可选,用于从数据库加载LLM配置)
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
Dict: 包含个性化建议的响应:
|
||||
@@ -470,14 +487,32 @@ class EmotionAnalyticsService:
|
||||
- suggestions: 建议列表(3-5条)
|
||||
"""
|
||||
try:
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}, config_id={config_id}")
|
||||
logger.info(f"生成个性化情绪建议: user={end_user_id}")
|
||||
|
||||
# 1. 如果提供了 config_id,从数据库加载配置
|
||||
if config_id is not None:
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
logger.warning(f"无法加载配置 config_id={config_id},将使用默认配置")
|
||||
# 1. 从 end_user_id 获取关联的 memory_config_id
|
||||
llm_client = None
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id is not None:
|
||||
from app.services.memory_config_service import (
|
||||
MemoryConfigService,
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"无法获取 end_user {end_user_id} 的配置,将使用默认配置: {e}")
|
||||
|
||||
# 2. 获取情绪健康数据
|
||||
health_data = await self.calculate_emotion_health_index(end_user_id, time_range="30d")
|
||||
@@ -498,8 +533,9 @@ class EmotionAnalyticsService:
|
||||
prompt = await self._build_suggestion_prompt(health_data, patterns, user_profile)
|
||||
|
||||
# 7. 调用LLM生成建议(使用配置中的LLM)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
llm_client = get_llm_client()
|
||||
if llm_client is None:
|
||||
# 无法获取配置时,抛出错误而不是使用默认配置
|
||||
raise ValueError("无法获取LLM配置,请确保end_user关联了有效的memory_config")
|
||||
|
||||
# 将 prompt 转换为 messages 格式
|
||||
messages = [
|
||||
@@ -598,7 +634,9 @@ class EmotionAnalyticsService:
|
||||
Returns:
|
||||
str: LLM prompt
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_suggestions_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_suggestions_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_suggestions_prompt(
|
||||
health_data=health_data,
|
||||
|
||||
@@ -9,10 +9,12 @@ Classes:
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.data_config_model import DataConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -50,7 +52,9 @@ class EmotionExtractionService:
|
||||
"""
|
||||
if self.llm_client is None or model_id:
|
||||
effective_model_id = model_id or self.llm_id
|
||||
self.llm_client = get_llm_client(effective_model_id)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(effective_model_id)
|
||||
return self.llm_client
|
||||
|
||||
async def extract_emotion(
|
||||
@@ -142,7 +146,9 @@ class EmotionExtractionService:
|
||||
Returns:
|
||||
Formatted prompt string for LLM
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_emotion_extraction_prompt
|
||||
from app.core.memory.utils.prompt.prompt_utils import (
|
||||
render_emotion_extraction_prompt,
|
||||
)
|
||||
|
||||
prompt = await render_emotion_extraction_prompt(
|
||||
statement=statement,
|
||||
|
||||
@@ -4,50 +4,48 @@ Memory Agent Service
|
||||
Handles business logic for memory agent operations including read/write services,
|
||||
health checks, and message type classification.
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.db import get_db
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.services.memory_konwledges_server import memory_konwledges_up, SimpleUser, find_document_id_by_kb_and_filename
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
from app.schemas.file_schema import CustomTextFileCreate
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
from langchain_mcp_adapters.tools import load_mcp_tools
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Initialize Neo4j connector for analytics functions
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
db_gen = get_db()
|
||||
db = next(db_gen)
|
||||
|
||||
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
@@ -257,14 +255,17 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> str:
|
||||
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Write operation result status
|
||||
@@ -272,24 +273,40 @@ class MemoryAgentService:
|
||||
Raises:
|
||||
ValueError: If config loading fails or write operation fails
|
||||
"""
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 如果 config_id 为 None,使用默认值 "17"
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Load configuration from database only
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation( operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg )
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
mcp_config = get_mcp_server_config()
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
@@ -300,20 +317,43 @@ class MemoryAgentService:
|
||||
async with client.session("data_flow") as session:
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
tools = await load_mcp_tools(session)
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, config_id=config_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph:
|
||||
logger.debug("Write graph created successfully")
|
||||
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": message, "config_id": config_id},
|
||||
{"messages": message, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
return self.writer_messages_deal(messages,start_time,group_id,config_id,message)
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.error(f"Write workflow failed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_details
|
||||
)
|
||||
|
||||
raise ValueError(f"Write workflow failed: {error_details}")
|
||||
|
||||
return self.writer_messages_deal(messages, start_time, group_id, config_id, message)
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
@@ -321,7 +361,8 @@ class MemoryAgentService:
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: str,
|
||||
config_id: Optional[str],
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> Dict:
|
||||
@@ -334,11 +375,14 @@ class MemoryAgentService:
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
group_id: Group identifier
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: User message
|
||||
history: Conversation history
|
||||
search_switch: Search mode switch
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
user_rag_memory_id: User RAG memory ID
|
||||
|
||||
Returns:
|
||||
Dict with 'answer' and 'intermediate_outputs' keys
|
||||
@@ -350,8 +394,18 @@ class MemoryAgentService:
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
if config_id==None:
|
||||
config_id = os.getenv("config_id")
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
|
||||
@@ -365,15 +419,19 @@ class MemoryAgentService:
|
||||
group_lock = self.get_group_lock(group_id)
|
||||
|
||||
with group_lock:
|
||||
# Step 1: Load configuration from database
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
config_loaded = reload_configuration_from_database(config_id)
|
||||
if not config_loaded:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}"
|
||||
# Step 1: Load configuration from database only
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# 记录失败的操作
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -387,8 +445,6 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
logger.info(f"Configuration loaded successfully for config_id: {config_id}")
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
@@ -404,45 +460,52 @@ class MemoryAgentService:
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass config_id to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, config_id=config_id,storage_type=storage_type,user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
# Pass memory_config to the graph workflow
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "config_id": config_id},
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
workflow_errors.extend(event.get('errors', []))
|
||||
|
||||
for msg in messages:
|
||||
msg_content = msg.content
|
||||
msg_role = msg.__class__.__name__.lower().replace("message", "")
|
||||
outputs.append({
|
||||
"role": msg.__class__.__name__.lower().replace("message", ""),
|
||||
"role": msg_role,
|
||||
"content": msg_content
|
||||
})
|
||||
|
||||
# Extract intermediate outputs
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
# Debug: log message type and content preview
|
||||
msg_type = msg.__class__.__name__
|
||||
content_preview = str(msg_content)[:200] if msg_content else "empty"
|
||||
logger.debug(f"Processing message type={msg_type}, content preview={content_preview}")
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
content_to_parse = msg_content
|
||||
if isinstance(msg_content, list):
|
||||
for block in msg_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
content_to_parse = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
# Try to parse content as JSON
|
||||
if isinstance(msg_content, str):
|
||||
if isinstance(content_to_parse, str):
|
||||
try:
|
||||
parsed = json.loads(msg_content)
|
||||
parsed = json.loads(content_to_parse)
|
||||
if isinstance(parsed, dict):
|
||||
# Debug: log what keys are in parsed
|
||||
logger.debug(f"Parsed dict keys: {list(parsed.keys())}")
|
||||
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in parsed:
|
||||
intermediate_data = parsed['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Found _intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
@@ -450,34 +513,14 @@ class MemoryAgentService:
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in parsed:
|
||||
logger.debug(f"Found _intermediates list with {len(parsed['_intermediates'])} items")
|
||||
for intermediate_data in parsed['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
logger.debug(f"Processing intermediate: {intermediate_data.get('type', 'unknown')}")
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
elif isinstance(msg_content, dict):
|
||||
# Check for single intermediate output
|
||||
if '_intermediate' in msg_content:
|
||||
intermediate_data = msg_content['_intermediate']
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
|
||||
# Check for multiple intermediate outputs (from Retrieve)
|
||||
if '_intermediates' in msg_content:
|
||||
for intermediate_data in msg_content['_intermediates']:
|
||||
output_key = self._create_intermediate_key(intermediate_data)
|
||||
|
||||
if output_key not in seen_intermediates:
|
||||
seen_intermediates.add(output_key)
|
||||
intermediate_outputs.append(self._format_intermediate_output(intermediate_data))
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
@@ -489,18 +532,57 @@ class MemoryAgentService:
|
||||
for messages in outputs:
|
||||
if messages['role'] == 'tool':
|
||||
message = messages['content']
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(message, list):
|
||||
# Extract text from MCP content blocks
|
||||
for block in message:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
message = block.get('text', '')
|
||||
break
|
||||
else:
|
||||
continue # No text block found
|
||||
|
||||
try:
|
||||
message = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(message, dict) and message.get('status') != '':
|
||||
summary_result = message.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
parsed = json.loads(message) if isinstance(message, str) else message
|
||||
if isinstance(parsed, dict):
|
||||
if parsed.get('status') == 'success':
|
||||
summary_result = parsed.get('summary_result')
|
||||
if summary_result:
|
||||
final_answer = summary_result
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# 记录成功的操作
|
||||
total_duration = time.time() - start_time
|
||||
if audit_logger:
|
||||
|
||||
# Check for workflow errors
|
||||
if workflow_errors:
|
||||
error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors])
|
||||
logger.warning(f"Read workflow completed with errors: {error_details}")
|
||||
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
success=False,
|
||||
duration=total_duration,
|
||||
error=error_details,
|
||||
details={
|
||||
"search_switch": search_switch,
|
||||
"history_length": len(history),
|
||||
"intermediate_outputs_count": len(intermediate_outputs),
|
||||
"has_answer": bool(final_answer),
|
||||
"errors": workflow_errors
|
||||
}
|
||||
)
|
||||
|
||||
# Raise error if no answer was produced
|
||||
if not final_answer:
|
||||
raise ValueError(f"Read workflow failed: {error_details}")
|
||||
|
||||
if audit_logger and not workflow_errors:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
@@ -612,19 +694,29 @@ class MemoryAgentService:
|
||||
else:
|
||||
return output
|
||||
|
||||
async def classify_message_type(self, message: str) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
message: User message to classify
|
||||
config_id: Configuration ID to load LLM model from database
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Type classification result
|
||||
"""
|
||||
logger.info("Classifying message type")
|
||||
|
||||
status = await status_typle(message)
|
||||
# Load configuration to get LLM model ID
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
@@ -790,7 +882,9 @@ class MemoryAgentService:
|
||||
async def get_user_profile(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user_id: Optional[str] = None
|
||||
current_user_id: Optional[str] = None,
|
||||
llm_id: Optional[str] = None,
|
||||
db: Session = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
@@ -801,6 +895,8 @@ class MemoryAgentService:
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选)
|
||||
- current_user_id: 当前登录用户的ID(保留参数)
|
||||
- llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成)
|
||||
- db: 数据库会话(可选)
|
||||
|
||||
返回格式:
|
||||
{
|
||||
@@ -817,7 +913,7 @@ class MemoryAgentService:
|
||||
|
||||
# 1. 根据 end_user_id 获取 end_user_name
|
||||
try:
|
||||
if end_user_id:
|
||||
if end_user_id and db:
|
||||
from app.repositories import end_user_repository
|
||||
from app.schemas.end_user_schema import EndUser as EndUserSchema
|
||||
|
||||
@@ -862,15 +958,19 @@ class MemoryAgentService:
|
||||
|
||||
await connector.close()
|
||||
|
||||
if not statements:
|
||||
if not statements or not llm_id:
|
||||
result["tags"] = []
|
||||
if not llm_id and statements:
|
||||
logger.warning("llm_id not provided, skipping tag generation")
|
||||
else:
|
||||
# 构建摘要文本
|
||||
summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}"
|
||||
logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities")
|
||||
|
||||
# 使用LLM提取标签
|
||||
llm_client = get_llm_client()
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(llm_id)
|
||||
|
||||
# 定义标签提取的结构
|
||||
class UserTags(BaseModel):
|
||||
@@ -1032,4 +1132,69 @@ class MemoryAgentService:
|
||||
# "msg": "解析失败",
|
||||
# "error_code": "DOC_PARSE_ERROR",
|
||||
# "data": {"error": str(e)}
|
||||
# }
|
||||
# }
|
||||
|
||||
|
||||
def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]:
|
||||
"""
|
||||
获取终端用户关联的记忆配置
|
||||
|
||||
通过以下流程获取配置:
|
||||
1. 根据 end_user_id 获取用户的 app_id
|
||||
2. 获取该应用的最新发布版本
|
||||
3. 从发布版本的 config 字段中提取 memory_config_id
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
包含 memory_config_id 和相关信息的字典
|
||||
|
||||
Raises:
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
|
||||
# 1. 获取 end_user 及其 app_id
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
if not end_user:
|
||||
logger.warning(f"End user not found: {end_user_id}")
|
||||
raise ValueError(f"终端用户不存在: {end_user_id}")
|
||||
|
||||
app_id = end_user.app_id
|
||||
logger.debug(f"Found end_user app_id: {app_id}")
|
||||
|
||||
# 2. 获取该应用的最新发布版本
|
||||
stmt = (
|
||||
select(AppRelease)
|
||||
.where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
|
||||
.order_by(AppRelease.version.desc())
|
||||
)
|
||||
latest_release = db.scalars(stmt).first()
|
||||
|
||||
if not latest_release:
|
||||
logger.warning(f"No active release found for app: {app_id}")
|
||||
raise ValueError(f"应用未发布: {app_id}")
|
||||
|
||||
logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}")
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
result = {
|
||||
"end_user_id": str(end_user_id),
|
||||
"app_id": str(app_id),
|
||||
"release_id": str(latest_release.id),
|
||||
"release_version": latest_release.version,
|
||||
"memory_config_id": memory_config_id
|
||||
}
|
||||
|
||||
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
|
||||
return result
|
||||
399
api/app/services/memory_config_service.py
Normal file
399
api/app/services/memory_config_service.py
Normal file
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Memory Configuration Service
|
||||
|
||||
Centralized configuration loading and management for memory services.
|
||||
This service eliminates code duplication between MemoryAgentService and MemoryStorageService.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.validators.memory_config_validators import (
|
||||
validate_and_resolve_model_id,
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
MemoryConfig,
|
||||
ModelInactiveError,
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {config_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
if isinstance(config_id, str):
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
config_logger.info(
|
||||
"Starting memory configuration loading",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
"Configuration not found in database",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"config_id": validated_config_id,
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"service": service_name,
|
||||
},
|
||||
)
|
||||
raise ConfigurationError(
|
||||
f"Configuration {validated_config_id} not found in database"
|
||||
)
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
# Validate embedding model
|
||||
embedding_uuid = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
|
||||
# Resolve LLM model
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=True,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Resolve optional rerank model
|
||||
rerank_uuid = None
|
||||
rerank_name = None
|
||||
if memory_config.rerank_id:
|
||||
rerank_uuid, rerank_name = validate_and_resolve_model_id(
|
||||
memory_config.rerank_id,
|
||||
"rerank",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
required=False,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Get embedding model name
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
config_name=memory_config.config_name,
|
||||
workspace_id=workspace.id,
|
||||
workspace_name=workspace.name,
|
||||
tenant_id=workspace.tenant_id,
|
||||
llm_model_id=llm_uuid,
|
||||
llm_model_name=llm_name,
|
||||
embedding_model_id=embedding_uuid,
|
||||
embedding_model_name=embedding_name,
|
||||
rerank_model_id=rerank_uuid,
|
||||
rerank_model_name=rerank_name,
|
||||
storage_type=workspace.storage_type or "neo4j",
|
||||
chunker_strategy=memory_config.chunker_strategy or "RecursiveChunker",
|
||||
reflexion_enabled=memory_config.enable_self_reflexion or False,
|
||||
reflexion_iteration_period=int(memory_config.iteration_period or "3"),
|
||||
reflexion_range=memory_config.reflexion_range or "retrieval",
|
||||
reflexion_baseline=memory_config.baseline or "time",
|
||||
loaded_at=datetime.now(),
|
||||
# Pipeline config: Deduplication
|
||||
enable_llm_dedup_blockwise=bool(memory_config.enable_llm_dedup_blockwise) if memory_config.enable_llm_dedup_blockwise is not None else False,
|
||||
enable_llm_disambiguation=bool(memory_config.enable_llm_disambiguation) if memory_config.enable_llm_disambiguation is not None else False,
|
||||
deep_retrieval=bool(memory_config.deep_retrieval) if memory_config.deep_retrieval is not None else True,
|
||||
t_type_strict=float(memory_config.t_type_strict) if memory_config.t_type_strict is not None else 0.8,
|
||||
t_name_strict=float(memory_config.t_name_strict) if memory_config.t_name_strict is not None else 0.8,
|
||||
t_overall=float(memory_config.t_overall) if memory_config.t_overall is not None else 0.8,
|
||||
# Pipeline config: Statement extraction
|
||||
statement_granularity=int(memory_config.statement_granularity) if memory_config.statement_granularity is not None else 2,
|
||||
include_dialogue_context=bool(memory_config.include_dialogue_context) if memory_config.include_dialogue_context is not None else False,
|
||||
max_dialogue_context_chars=int(memory_config.max_context) if memory_config.max_context is not None else 1000,
|
||||
# Pipeline config: Forgetting engine
|
||||
lambda_time=float(memory_config.lambda_time) if memory_config.lambda_time is not None else 0.5,
|
||||
lambda_mem=float(memory_config.lambda_mem) if memory_config.lambda_mem is not None else 0.5,
|
||||
offset=float(memory_config.offset) if memory_config.offset is not None else 0.0,
|
||||
# Pipeline config: Pruning
|
||||
pruning_enabled=bool(memory_config.pruning_enabled) if memory_config.pruning_enabled is not None else False,
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.info(
|
||||
"Memory configuration loaded successfully",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": validated_config_id,
|
||||
"config_name": config.config_name,
|
||||
"workspace_id": str(config.workspace_id),
|
||||
"load_result": "success",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
config_logger.error(
|
||||
"Failed to load memory configuration",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
"load_result": "error",
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.error(f"Failed to load memory configuration {config_id}: {e}")
|
||||
if isinstance(e, (ConfigurationError, ValueError)):
|
||||
raise
|
||||
else:
|
||||
raise ConfigurationError(f"Failed to load configuration {config_id}: {e}")
|
||||
|
||||
def get_model_config(self, model_id: str) -> dict:
|
||||
"""Get LLM model configuration by ID.
|
||||
|
||||
Args:
|
||||
model_id: Model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with model configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=model_id)
|
||||
if not config:
|
||||
logger.warning(f"Model ID {model_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": settings.LLM_TIMEOUT,
|
||||
"max_retries": settings.LLM_MAX_RETRIES,
|
||||
}
|
||||
|
||||
def get_embedder_config(self, embedding_id: str) -> dict:
|
||||
"""Get embedding model configuration by ID.
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding model ID to look up
|
||||
|
||||
Returns:
|
||||
Dict with embedder configuration including api_key, base_url, etc.
|
||||
"""
|
||||
from app.models.models_model import ModelApiKey
|
||||
from app.services.model_service import ModelConfigService as ModelSvc
|
||||
from fastapi import status
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
config = ModelSvc.get_model_by_id(db=self.db, model_id=embedding_id)
|
||||
if not config:
|
||||
logger.warning(f"Embedding model ID {embedding_id} not found")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="嵌入模型ID不存在")
|
||||
|
||||
api_config: ModelApiKey = config.api_keys[0]
|
||||
|
||||
return {
|
||||
"model_name": api_config.model_name,
|
||||
"provider": api_config.provider,
|
||||
"api_key": api_config.api_key,
|
||||
"base_url": api_config.api_base,
|
||||
"model_config_id": api_config.model_config_id,
|
||||
"type": config.type,
|
||||
"timeout": 120.0,
|
||||
"max_retries": 5,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_config(memory_config: MemoryConfig):
|
||||
"""Build ExtractionPipelineConfig from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing all pipeline settings.
|
||||
|
||||
Returns:
|
||||
ExtractionPipelineConfig with deduplication, statement extraction,
|
||||
and forgetting engine settings.
|
||||
"""
|
||||
from app.core.memory.models.variate_config import (
|
||||
DedupConfig,
|
||||
ExtractionPipelineConfig,
|
||||
ForgettingEngineConfig,
|
||||
StatementExtractionConfig,
|
||||
)
|
||||
|
||||
dedup_config = DedupConfig(
|
||||
enable_llm_dedup_blockwise=memory_config.enable_llm_dedup_blockwise,
|
||||
enable_llm_disambiguation=memory_config.enable_llm_disambiguation,
|
||||
fuzzy_name_threshold_strict=memory_config.t_name_strict,
|
||||
fuzzy_type_threshold_strict=memory_config.t_type_strict,
|
||||
fuzzy_overall_threshold=memory_config.t_overall,
|
||||
)
|
||||
|
||||
stmt_config = StatementExtractionConfig(
|
||||
statement_granularity=memory_config.statement_granularity,
|
||||
include_dialogue_context=memory_config.include_dialogue_context,
|
||||
max_dialogue_context_chars=memory_config.max_dialogue_context_chars,
|
||||
)
|
||||
|
||||
forget_config = ForgettingEngineConfig(
|
||||
offset=memory_config.offset,
|
||||
lambda_time=memory_config.lambda_time,
|
||||
lambda_mem=memory_config.lambda_mem,
|
||||
)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_pruning_config(memory_config: MemoryConfig) -> dict:
|
||||
"""Retrieve semantic pruning config from MemoryConfig.
|
||||
|
||||
Args:
|
||||
memory_config: MemoryConfig object containing pruning settings.
|
||||
|
||||
Returns:
|
||||
Dict suitable for PruningConfig.model_validate with keys:
|
||||
- pruning_switch: bool
|
||||
- pruning_scene: str
|
||||
- pruning_threshold: float
|
||||
"""
|
||||
return {
|
||||
"pruning_switch": memory_config.pruning_enabled,
|
||||
"pruning_scene": memory_config.pruning_scene,
|
||||
"pruning_threshold": memory_config.pruning_threshold,
|
||||
}
|
||||
@@ -272,7 +272,7 @@ async def get_workspace_total_memory_count(
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
user_name = end_user.name if end_user else None
|
||||
user_name = end_user.other_name if end_user else None
|
||||
|
||||
return {
|
||||
"total_memory_count": search_result.get("total", 0),
|
||||
@@ -298,10 +298,10 @@ async def get_workspace_total_memory_count(
|
||||
details.append({
|
||||
"end_user_id": end_user_id_str,
|
||||
"count": host_total,
|
||||
"name": host.name # 添加 name 字段
|
||||
"name": host.other_name # 使用 other_name 字段
|
||||
})
|
||||
|
||||
business_logger.debug(f"EndUser {end_user_id_str} ({host.name}) 记忆数: {host_total}")
|
||||
business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}")
|
||||
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
|
||||
@@ -309,7 +309,7 @@ async def get_workspace_total_memory_count(
|
||||
details.append({
|
||||
"end_user_id": str(host.id),
|
||||
"count": 0,
|
||||
"name": host.name # 添加 name 字段
|
||||
"name": host.other_name # 使用 other_name 字段
|
||||
})
|
||||
|
||||
result = {
|
||||
|
||||
@@ -4,38 +4,37 @@ Memory Storage Service
|
||||
Handles business logic for memory storage operations.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any, AsyncGenerator
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.models.user_model import User
|
||||
from app.core.logging_config import get_logger
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigPilotRun,
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
ConfigKey,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
import uuid
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
# Load environment variables for Neo4j connector
|
||||
load_dotenv()
|
||||
@@ -247,7 +246,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
RuntimeError: 当管线执行失败时
|
||||
"""
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
dbrun_path = os.path.join(project_root, "app", "core", "memory", "dbrun.json")
|
||||
|
||||
try:
|
||||
# 发出初始进度事件
|
||||
@@ -256,24 +254,12 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"time": int(time.time() * 1000)
|
||||
})
|
||||
|
||||
# 步骤 1: 配置加载和验证(复用现有逻辑)
|
||||
# 步骤 1: 配置加载和验证(数据库优先)
|
||||
payload_cid = str(getattr(payload, "config_id", "") or "").strip()
|
||||
cid: Optional[str] = payload_cid if payload_cid else None
|
||||
|
||||
if not cid and os.path.isfile(dbrun_path):
|
||||
try:
|
||||
with open(dbrun_path, "r", encoding="utf-8") as f:
|
||||
dbrun = json.load(f)
|
||||
if isinstance(dbrun, dict):
|
||||
sel = dbrun.get("selections", {})
|
||||
if isinstance(sel, dict):
|
||||
fallback_cid = str(sel.get("config_id") or "").strip()
|
||||
cid = fallback_cid or None
|
||||
except Exception:
|
||||
cid = None
|
||||
|
||||
if not cid:
|
||||
raise ValueError("未提供 payload.config_id,且 dbrun.json 未设置 selections.config_id,禁止启动试运行")
|
||||
raise ValueError("未提供 payload.config_id,禁止启动试运行")
|
||||
|
||||
# 验证 dialogue_text 必须提供
|
||||
dialogue_text = payload.dialogue_text.strip() if payload.dialogue_text else ""
|
||||
@@ -281,12 +267,16 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not dialogue_text:
|
||||
raise ValueError("试运行模式必须提供 dialogue_text 参数")
|
||||
|
||||
# 应用内存覆写并刷新常量
|
||||
from app.core.memory.utils.config.definitions import reload_configuration_from_database
|
||||
|
||||
ok_override = reload_configuration_from_database(cid)
|
||||
if not ok_override:
|
||||
raise RuntimeError("运行时覆写失败,config_id 无效或刷新常量失败")
|
||||
# Load configuration from database only using centralized manager
|
||||
try:
|
||||
config_service = MemoryConfigService(self.db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(cid),
|
||||
service_name="MemoryStorageService.pilot_run_stream"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
raise RuntimeError(f"Configuration loading failed: {e}")
|
||||
|
||||
# 步骤 2: 创建进度回调函数捕获管线进度
|
||||
# 使用队列在回调和生成器之间传递进度事件
|
||||
@@ -307,13 +297,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
async def run_pipeline():
|
||||
"""在后台执行管线并捕获异常"""
|
||||
try:
|
||||
from app.core.memory.main import main as pipeline_main
|
||||
from app.services.pilot_run_service import run_pilot_extraction
|
||||
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
|
||||
await pipeline_main(
|
||||
dialogue_text=dialogue_text,
|
||||
is_pilot_run=True,
|
||||
progress_callback=progress_callback
|
||||
logger.info(f"[PILOT_RUN_STREAM] Calling run_pilot_extraction with dialogue_text length: {len(dialogue_text)}")
|
||||
await run_pilot_extraction(
|
||||
memory_config=memory_config,
|
||||
dialogue_text=dialogue_text,
|
||||
db=self.db,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
logger.info("[PILOT_RUN_STREAM] pipeline_main completed")
|
||||
|
||||
|
||||
219
api/app/services/pilot_run_service.py
Normal file
219
api/app/services/pilot_run_service.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Pilot Run Service - 试运行服务
|
||||
|
||||
用于执行记忆系统的试运行流程,不保存到 Neo4j。
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger, log_time
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
get_chunked_dialogs_from_preprocessed,
|
||||
)
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_pipeline_config,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
async def run_pilot_extraction(
|
||||
memory_config: MemoryConfig,
|
||||
dialogue_text: str,
|
||||
db: Session,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[dict]], Awaitable[None]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行试运行模式的知识提取流水线。
|
||||
|
||||
Args:
|
||||
memory_config: 从数据库加载的内存配置对象
|
||||
dialogue_text: 输入的对话文本
|
||||
progress_callback: 可选的进度回调函数
|
||||
- 参数1 (stage): 当前处理阶段标识符
|
||||
- 参数2 (message): 人类可读的进度消息
|
||||
- 参数3 (data): 可选的附加数据字典
|
||||
"""
|
||||
log_file = "logs/time.log"
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"\n=== Pilot Run Started: {timestamp} ===\n")
|
||||
|
||||
pipeline_start = time.time()
|
||||
neo4j_connector = None
|
||||
|
||||
try:
|
||||
# 步骤 1: 初始化客户端
|
||||
logger.info("Initializing clients...")
|
||||
step_start = time.time()
|
||||
|
||||
client_factory = MemoryClientFactory(db)
|
||||
llm_client = client_factory.get_llm_client(str(memory_config.llm_model_id))
|
||||
embedder_client = client_factory.get_embedder_client(str(memory_config.embedding_model_id))
|
||||
|
||||
neo4j_connector = Neo4jConnector()
|
||||
|
||||
log_time("Client Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 2: 解析对话文本
|
||||
logger.info("Parsing dialogue text...")
|
||||
step_start = time.time()
|
||||
|
||||
# 解析对话文本,支持 "用户:" 和 "AI:" 格式
|
||||
pattern = r"(用户|AI)[::]\s*([^\n]+(?:\n(?!(?:用户|AI)[::])[^\n]*)*?)"
|
||||
matches = re.findall(pattern, dialogue_text, re.MULTILINE | re.DOTALL)
|
||||
messages = [
|
||||
ConversationMessage(role=r, msg=c.strip())
|
||||
for r, c in matches
|
||||
if c.strip()
|
||||
]
|
||||
|
||||
# 如果没有匹配到格式化的对话,将整个文本作为用户消息
|
||||
if not messages:
|
||||
messages = [ConversationMessage(role="用户", msg=dialogue_text.strip())]
|
||||
|
||||
context = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=str(memory_config.workspace_id),
|
||||
user_id=str(memory_config.tenant_id),
|
||||
apply_id=str(memory_config.config_id),
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("text_preprocessing", "开始预处理文本...")
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs_from_preprocessed(
|
||||
data=[dialog],
|
||||
chunker_strategy=memory_config.chunker_strategy,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
logger.info(f"Processed dialogue text: {len(messages)} messages")
|
||||
|
||||
# 进度回调:输出每个分块的结果
|
||||
if progress_callback:
|
||||
for dlg in chunked_dialogs:
|
||||
for i, chunk in enumerate(dlg.chunks):
|
||||
chunk_result = {
|
||||
"chunk_index": i + 1,
|
||||
"content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content,
|
||||
"full_length": len(chunk.content),
|
||||
"dialog_id": dlg.id,
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result)
|
||||
|
||||
preprocessing_summary = {
|
||||
"total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs),
|
||||
"total_dialogs": len(chunked_dialogs),
|
||||
"chunker_strategy": memory_config.chunker_strategy,
|
||||
}
|
||||
await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary)
|
||||
|
||||
log_time("Data Loading & Chunking", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 3: 初始化流水线编排器
|
||||
logger.info("Initializing extraction orchestrator...")
|
||||
step_start = time.time()
|
||||
|
||||
config = get_pipeline_config(memory_config)
|
||||
logger.info(
|
||||
f"Pipeline config loaded: enable_llm_dedup_blockwise={config.deduplication.enable_llm_dedup_blockwise}, "
|
||||
f"enable_llm_disambiguation={config.deduplication.enable_llm_disambiguation}"
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=config,
|
||||
progress_callback=progress_callback,
|
||||
embedding_id=str(memory_config.embedding_model_id),
|
||||
)
|
||||
|
||||
log_time("Orchestrator Initialization", time.time() - step_start, log_file)
|
||||
|
||||
# 步骤 4: 执行知识提取流水线
|
||||
logger.info("Running extraction pipeline...")
|
||||
step_start = time.time()
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("knowledge_extraction", "正在知识抽取...")
|
||||
|
||||
extraction_result = await orchestrator.run(
|
||||
dialog_data_list=chunked_dialogs,
|
||||
is_pilot_run=True,
|
||||
)
|
||||
|
||||
# 解包 extraction_result tuple (与 main.py 保持一致)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_edges,
|
||||
) = extraction_result
|
||||
|
||||
log_time("Extraction Pipeline", time.time() - step_start, log_file)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("generating_results", "正在生成结果...")
|
||||
|
||||
# 步骤 5: 生成记忆摘要(与 main.py 保持一致)
|
||||
try:
|
||||
logger.info("Generating memory summaries...")
|
||||
step_start = time.time()
|
||||
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
)
|
||||
|
||||
log_time("Memory Summary Generation", time.time() - step_start, log_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Memory summary step failed: {e}", exc_info=True)
|
||||
|
||||
logger.info("Pilot run completed: Skipping Neo4j save")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pilot run failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if neo4j_connector:
|
||||
try:
|
||||
await neo4j_connector.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
total_time = time.time() - pipeline_start
|
||||
log_time("TOTAL PILOT RUN TIME", total_time, log_file)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
with open(log_file, "a", encoding="utf-8") as f:
|
||||
f.write(f"=== Pilot Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info(f"Pilot run complete. Total time: {total_time:.2f}s")
|
||||
@@ -4,15 +4,15 @@ User Memory Service
|
||||
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -284,8 +284,7 @@ class UserMemoryService:
|
||||
# 使用 end_user_id 调用分析函数
|
||||
try:
|
||||
logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
|
||||
result = await analytics_user_summary(end_user_id)
|
||||
summary = result.get("summary", "")
|
||||
summary = await generate_user_summary(end_user_id)
|
||||
|
||||
if not summary:
|
||||
logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")
|
||||
@@ -535,6 +534,112 @@ async def analytics_node_statistics(
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_memory_types(
|
||||
db: Session,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
统计8种记忆类型的数量和百分比
|
||||
|
||||
计算规则:
|
||||
1. 感知记忆 (PERCEPTUAL_MEMORY) = statement + entity
|
||||
2. 工作记忆 (WORKING_MEMORY) = chunk + entity
|
||||
3. 短期记忆 (SHORT_TERM_MEMORY) = chunk
|
||||
4. 长期记忆 (LONG_TERM_MEMORY) = entity
|
||||
5. 显性记忆 (EXPLICIT_MEMORY) = 1/2 * entity
|
||||
6. 隐性记忆 (IMPLICIT_MEMORY) = 1/3 * entity
|
||||
7. 情绪记忆 (EMOTIONAL_MEMORY) = statement
|
||||
8. 情景记忆 (EPISODIC_MEMORY) = memory_summary
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 可选的终端用户ID (UUID),用于过滤特定用户的节点
|
||||
|
||||
Returns:
|
||||
[
|
||||
{
|
||||
"type": str, # 记忆类型枚举值 (如 PERCEPTUAL_MEMORY, WORKING_MEMORY 等)
|
||||
"count": int, # 该类型的数量
|
||||
"percentage": float # 该类型在所有记忆中的占比
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
记忆类型枚举值:
|
||||
- PERCEPTUAL_MEMORY: 感知记忆
|
||||
- WORKING_MEMORY: 工作记忆
|
||||
- SHORT_TERM_MEMORY: 短期记忆
|
||||
- LONG_TERM_MEMORY: 长期记忆
|
||||
- EXPLICIT_MEMORY: 显性记忆
|
||||
- IMPLICIT_MEMORY: 隐性记忆
|
||||
- EMOTIONAL_MEMORY: 情绪记忆
|
||||
- EPISODIC_MEMORY: 情景记忆
|
||||
"""
|
||||
# 定义需要查询的节点类型
|
||||
node_types = {
|
||||
"Statement": "Statement",
|
||||
"Entity": "ExtractedEntity",
|
||||
"Chunk": "Chunk",
|
||||
"MemorySummary": "MemorySummary"
|
||||
}
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
node_counts = {}
|
||||
|
||||
# 查询每种节点类型的数量
|
||||
for key, node_type in node_types.items():
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query)
|
||||
|
||||
# 提取计数结果
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
node_counts[key] = count
|
||||
|
||||
# 获取各节点类型的数量
|
||||
statement_count = node_counts.get("Statement", 0)
|
||||
entity_count = node_counts.get("Entity", 0)
|
||||
chunk_count = node_counts.get("Chunk", 0)
|
||||
memory_summary_count = node_counts.get("MemorySummary", 0)
|
||||
|
||||
# 按规则计算8种记忆类型的数量(使用英文枚举作为key)
|
||||
memory_counts = {
|
||||
"PERCEPTUAL_MEMORY": statement_count + entity_count, # 感知记忆
|
||||
"WORKING_MEMORY": chunk_count + entity_count, # 工作记忆
|
||||
"SHORT_TERM_MEMORY": chunk_count, # 短期记忆
|
||||
"LONG_TERM_MEMORY": entity_count, # 长期记忆
|
||||
"EXPLICIT_MEMORY": entity_count // 2, # 显性记忆 (1/2 entity)
|
||||
"IMPLICIT_MEMORY": entity_count // 3, # 隐性记忆 (1/3 entity)
|
||||
"EMOTIONAL_MEMORY": statement_count, # 情绪记忆
|
||||
"EPISODIC_MEMORY": memory_summary_count # 情景记忆
|
||||
}
|
||||
|
||||
# 计算总数
|
||||
total = sum(memory_counts.values())
|
||||
|
||||
# 构建返回数据,包含 type、count 和 percentage
|
||||
memory_types = []
|
||||
for memory_type, count in memory_counts.items():
|
||||
percentage = round((count / total * 100), 2) if total > 0 else 0.0
|
||||
memory_types.append({
|
||||
"type": memory_type,
|
||||
"count": count,
|
||||
"percentage": percentage
|
||||
})
|
||||
|
||||
return memory_types
|
||||
|
||||
|
||||
async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
|
||||
@@ -5,29 +5,37 @@ import uuid
|
||||
from os import getenv
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, PermissionDeniedException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import Workspace, WorkspaceRole, InviteStatus, WorkspaceMember
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.schemas.workspace_schema import (
|
||||
WorkspaceCreate,
|
||||
WorkspaceUpdate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceInviteResponse,
|
||||
InviteValidateResponse,
|
||||
InviteAcceptRequest,
|
||||
WorkspaceMemberUpdate
|
||||
)
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
from app.models.workspace_model import (
|
||||
InviteStatus,
|
||||
Workspace,
|
||||
WorkspaceMember,
|
||||
WorkspaceRole,
|
||||
)
|
||||
from app.repositories import workspace_repository
|
||||
from app.repositories.workspace_invite_repository import WorkspaceInviteRepository
|
||||
from app.schemas.workspace_schema import (
|
||||
InviteAcceptRequest,
|
||||
InviteValidateResponse,
|
||||
WorkspaceCreate,
|
||||
WorkspaceInviteCreate,
|
||||
WorkspaceInviteResponse,
|
||||
WorkspaceMemberUpdate,
|
||||
WorkspaceModelsUpdate,
|
||||
WorkspaceUpdate,
|
||||
)
|
||||
from dotenv import load_dotenv
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取业务逻辑专用日志器
|
||||
business_logger = get_business_logger()
|
||||
load_dotenv()
|
||||
def switch_workspace(
|
||||
db: Session,
|
||||
@@ -131,10 +139,9 @@ def create_workspace(
|
||||
f"{db_workspace.id} 创建知识库"
|
||||
)
|
||||
try:
|
||||
import os
|
||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||
from app.models.knowledge_model import KnowledgeType, PermissionType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.knowledge_schema import KnowledgeCreate
|
||||
|
||||
# 创建知识库数据
|
||||
knowledge_data = KnowledgeCreate(
|
||||
@@ -229,7 +236,7 @@ def get_workspace_members(
|
||||
)
|
||||
|
||||
# 权限检查:工作空间成员或超级管理员可以查看成员列表
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
@@ -322,7 +329,7 @@ def _check_workspace_admin_permission(db: Session, workspace_id: uuid.UUID, user
|
||||
)
|
||||
|
||||
# 使用统一权限服务检查管理权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
from app.core.permissions import Action, Resource, Subject, permission_service
|
||||
|
||||
# 获取用户的工作空间成员关系
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
@@ -800,3 +807,53 @@ def get_workspace_models_configs(
|
||||
)
|
||||
return configs
|
||||
|
||||
|
||||
def update_workspace_models_configs(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
user: User,
|
||||
) -> Workspace:
|
||||
"""更新工作空间的模型配置(llm, embedding, rerank)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
models_update: 模型配置更新对象
|
||||
user: 当前用户
|
||||
|
||||
Returns:
|
||||
Workspace: 更新后的工作空间对象
|
||||
"""
|
||||
business_logger.info(f"用户 {user.username} 请求更新工作空间 {workspace_id} 的模型配置")
|
||||
|
||||
# 检查用户是否有管理员权限
|
||||
db_workspace = _check_workspace_admin_permission(db, workspace_id, user)
|
||||
|
||||
try:
|
||||
if models_update.llm is not None:
|
||||
db_workspace.llm = str(models_update.llm) if models_update.llm else None
|
||||
business_logger.debug(f"更新LLM配置: {models_update.llm}")
|
||||
|
||||
if models_update.embedding is not None:
|
||||
db_workspace.embedding = str(models_update.embedding) if models_update.embedding else None
|
||||
business_logger.debug(f"更新嵌入模型配置: {models_update.embedding}")
|
||||
|
||||
if models_update.rerank is not None:
|
||||
db_workspace.rerank = str(models_update.rerank) if models_update.rerank else None
|
||||
business_logger.debug(f"更新重排序模型配置: {models_update.rerank}")
|
||||
|
||||
db.add(db_workspace)
|
||||
db.commit()
|
||||
db.refresh(db_workspace)
|
||||
|
||||
business_logger.info(
|
||||
f"工作空间模型配置更新成功: workspace_id={workspace_id}, "
|
||||
f"llm={db_workspace.llm}, embedding={db_workspace.embedding}, rerank={db_workspace.rerank}"
|
||||
)
|
||||
return db_workspace
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"工作空间模型配置更新失败: workspace_id={workspace_id} - {str(e)}")
|
||||
db.rollback()
|
||||
raise BusinessException(f"更新模型配置失败: {str(e)}", BizCode.INTERNAL_ERROR)
|
||||
|
||||
Reference in New Issue
Block a user