Fix/memory bug fix (#171)
This commit is contained in:
@@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
try:
|
||||
memory_content = asyncio.run(
|
||||
MemoryAgentService().read_memory(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
message=question,
|
||||
history=[],
|
||||
search_switch="2",
|
||||
|
||||
@@ -75,7 +75,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 调用仓储层查询
|
||||
tags = await self.emotion_repo.get_emotion_tags(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
emotion_type=emotion_type,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
@@ -157,7 +157,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 调用仓储层查询
|
||||
keywords = await self.emotion_repo.get_emotion_wordcloud(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
emotion_type=emotion_type,
|
||||
limit=limit
|
||||
)
|
||||
@@ -339,7 +339,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 获取时间范围内的情绪数据
|
||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
time_range=time_range
|
||||
)
|
||||
|
||||
@@ -505,7 +505,7 @@ class EmotionAnalyticsService:
|
||||
)
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=int(config_id),
|
||||
config_id=(config_id),
|
||||
service_name="EmotionAnalyticsService.generate_emotion_suggestions"
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -519,7 +519,7 @@ class EmotionAnalyticsService:
|
||||
|
||||
# 3. 获取情绪数据用于模式分析
|
||||
emotions = await self.emotion_repo.get_emotions_in_range(
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
time_range="30d"
|
||||
)
|
||||
|
||||
@@ -598,13 +598,13 @@ class EmotionAnalyticsService:
|
||||
# 查询用户的实体和标签
|
||||
query = """
|
||||
MATCH (e:Entity)
|
||||
WHERE e.group_id = $group_id
|
||||
WHERE e.end_user_id = $end_user_id
|
||||
RETURN e.name as name, e.type as type
|
||||
ORDER BY e.created_at DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
entities = await connector.execute_query(query, group_id=end_user_id)
|
||||
entities = await connector.execute_query(query, end_user_id=end_user_id)
|
||||
|
||||
# 提取兴趣标签
|
||||
interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5]
|
||||
|
||||
@@ -8,9 +8,11 @@ Classes:
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -37,7 +39,7 @@ class EmotionConfigService:
|
||||
self.db = db
|
||||
logger.info("情绪配置服务初始化完成")
|
||||
|
||||
def get_emotion_config(self, config_id: int) -> Dict[str, Any]:
|
||||
def get_emotion_config(self, config_id: UUID) -> Dict[str, Any]:
|
||||
"""获取情绪引擎配置
|
||||
|
||||
查询指定配置ID的情绪相关配置字段。
|
||||
@@ -61,8 +63,8 @@ class EmotionConfigService:
|
||||
logger.info(f"获取情绪配置: config_id={config_id}")
|
||||
|
||||
# 查询配置
|
||||
config = self.db.query(DataConfig).filter(
|
||||
DataConfig.config_id == config_id
|
||||
config = self.db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id == config_id
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
@@ -144,7 +146,7 @@ class EmotionConfigService:
|
||||
|
||||
def update_emotion_config(
|
||||
self,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
config_data: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""更新情绪引擎配置
|
||||
@@ -173,8 +175,8 @@ class EmotionConfigService:
|
||||
self.validate_emotion_config(config_data)
|
||||
|
||||
# 查询配置
|
||||
config = self.db.query(DataConfig).filter(
|
||||
DataConfig.config_id == config_id
|
||||
config = self.db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id == config_id
|
||||
).first()
|
||||
|
||||
if not config:
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
from app.core.memory.models.emotion_models import EmotionExtraction
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -60,7 +60,7 @@ class EmotionExtractionService:
|
||||
async def extract_emotion(
|
||||
self,
|
||||
statement: str,
|
||||
config: DataConfig
|
||||
config: MemoryConfig
|
||||
) -> Optional[EmotionExtraction]:
|
||||
"""Extract emotion information from a statement.
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
@@ -27,6 +28,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
@@ -35,6 +37,7 @@ from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
@@ -54,25 +57,24 @@ _neo4j_connector = Neo4jConnector()
|
||||
class MemoryAgentService:
|
||||
"""Service for memory agent operations"""
|
||||
|
||||
def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context):
|
||||
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
|
||||
duration = time.time() - start_time
|
||||
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {group_id} with config_id {config_id}")
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True,
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
logger.warning(f"Write operation failed for group {group_id}")
|
||||
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||
|
||||
# 记录失败的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
@@ -263,13 +265,13 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -284,15 +286,15 @@ class MemoryAgentService:
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
@@ -312,7 +314,7 @@ class MemoryAgentService:
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -320,24 +322,23 @@ class MemoryAgentService:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(group_id, message_text, user_rag_memory_id)
|
||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
from langchain_core.messages import AIMessage
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
@@ -354,14 +355,14 @@ class MemoryAgentService:
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg)
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
@@ -369,15 +370,14 @@ class MemoryAgentService:
|
||||
|
||||
async def read_memory(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: Optional[str],
|
||||
config_id: Optional[UUID],
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str
|
||||
) -> Dict:
|
||||
user_rag_memory_id: str) -> Dict:
|
||||
"""
|
||||
Process read operation with config_id
|
||||
|
||||
@@ -387,7 +387,7 @@ class MemoryAgentService:
|
||||
- "2": Direct answer based on context
|
||||
|
||||
Args:
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
end_user_id: Group identifier (also used as end_user_id)
|
||||
message: User message
|
||||
history: Conversation history
|
||||
search_switch: Search mode switch
|
||||
@@ -405,22 +405,22 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
|
||||
ori_message= message
|
||||
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
try:
|
||||
config_id = get_end_user_connected_config(group_id, db)
|
||||
config_id=config_id.get('memory_config_id')
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
if config_id is None:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||
|
||||
logger.info(f"Read operation for group {group_id} with config_id {config_id}")
|
||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
@@ -448,7 +448,7 @@ class MemoryAgentService:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
@@ -458,16 +458,16 @@ class MemoryAgentService:
|
||||
|
||||
# Step 2: Prepare history
|
||||
history.append({"role": "user", "content": message})
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
graph_exec_start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"group_id": group_id
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
@@ -562,13 +562,13 @@ class MemoryAgentService:
|
||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=group_id,
|
||||
messages=message,
|
||||
end_user_id=end_user_id,
|
||||
messages=ori_message,
|
||||
aimessages=summary,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
|
||||
logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}")
|
||||
else:
|
||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
@@ -584,7 +584,7 @@ class MemoryAgentService:
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
@@ -596,20 +596,20 @@ class MemoryAgentService:
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
"""
|
||||
Get standardized message list from user input.
|
||||
@@ -654,7 +654,7 @@ class MemoryAgentService:
|
||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||
return user_input.messages
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
@@ -681,10 +681,9 @@ class MemoryAgentService:
|
||||
status = await status_typle(message, memory_config.llm_model_id)
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
async def generate_summary_from_retrieve(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
retrieve_info: str,
|
||||
history: List[Dict],
|
||||
query: str,
|
||||
@@ -708,16 +707,16 @@ class MemoryAgentService:
|
||||
"""
|
||||
if config_id is None:
|
||||
try:
|
||||
config_id = get_end_user_connected_config(group_id, db)
|
||||
config_id = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = config_id.get('memory_config_id')
|
||||
if config_id is None:
|
||||
raise ValueError(
|
||||
f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}")
|
||||
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
|
||||
|
||||
try:
|
||||
@@ -727,6 +726,7 @@ class MemoryAgentService:
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
|
||||
# 导入必要的模块
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
|
||||
from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse
|
||||
@@ -766,7 +766,7 @@ class MemoryAgentService:
|
||||
"""
|
||||
统计知识库类型分布,包含:
|
||||
1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤)
|
||||
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤)
|
||||
2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤)
|
||||
3. total: 所有类型的总和
|
||||
|
||||
参数:
|
||||
@@ -852,11 +852,11 @@ class MemoryAgentService:
|
||||
for end_user in end_users:
|
||||
end_user_id_str = str(end_user.id)
|
||||
memory_query = """
|
||||
MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count
|
||||
MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count
|
||||
"""
|
||||
neo4j_result = await _neo4j_connector.execute_query(
|
||||
memory_query,
|
||||
group_id=end_user_id_str,
|
||||
end_user_id=end_user_id_str,
|
||||
)
|
||||
chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0
|
||||
total_chunks += chunk_count
|
||||
@@ -896,7 +896,7 @@ class MemoryAgentService:
|
||||
获取指定用户的热门记忆标签
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选),对应Neo4j中的group_id字段
|
||||
- end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段
|
||||
- limit: 返回标签数量限制
|
||||
|
||||
返回格式:
|
||||
@@ -906,7 +906,7 @@ class MemoryAgentService:
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度)
|
||||
# by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度)
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||
payload=[]
|
||||
for tag, freq in tags:
|
||||
@@ -981,21 +981,21 @@ class MemoryAgentService:
|
||||
# 查询该用户的语句
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL "
|
||||
"WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement "
|
||||
"ORDER BY s.created_at DESC LIMIT 100"
|
||||
)
|
||||
rows = await connector.execute_query(query, group_id=end_user_id)
|
||||
rows = await connector.execute_query(query, end_user_id=end_user_id)
|
||||
statements = [r.get("statement", "") for r in rows if r.get("statement")]
|
||||
|
||||
# 查询该用户的热门实体
|
||||
entity_query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
|
||||
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC LIMIT 20"
|
||||
)
|
||||
entity_rows = await connector.execute_query(entity_query, group_id=end_user_id)
|
||||
entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id)
|
||||
entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows]
|
||||
|
||||
await connector.close()
|
||||
@@ -1048,14 +1048,14 @@ class MemoryAgentService:
|
||||
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
||||
hot_tag_query = (
|
||||
"MATCH (e:ExtractedEntity) "
|
||||
"WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' "
|
||||
"WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' "
|
||||
"AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude "
|
||||
"RETURN e.name AS name, count(e) AS frequency "
|
||||
"ORDER BY frequency DESC LIMIT 4"
|
||||
)
|
||||
hot_tag_rows = await connector.execute_query(
|
||||
hot_tag_query,
|
||||
group_id=end_user_id,
|
||||
end_user_id=end_user_id,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
await connector.close()
|
||||
@@ -1189,6 +1189,16 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
|
||||
# 3. 从 config 中提取 memory_config_id
|
||||
config = latest_release.config or {}
|
||||
|
||||
# 如果 config 是字符串,解析为字典
|
||||
if isinstance(config, str):
|
||||
import json
|
||||
try:
|
||||
config = json.loads(config)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
|
||||
config = {}
|
||||
|
||||
memory_obj = config.get('memory', {})
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
@@ -1227,7 +1237,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||||
@@ -1240,10 +1250,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
|
||||
# 1. 批量查询所有 end_user 及其 app_id
|
||||
end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all()
|
||||
|
||||
|
||||
# 创建 end_user_id -> app_id 的映射
|
||||
user_to_app = {str(eu.id): eu.app_id for eu in end_users}
|
||||
|
||||
|
||||
# 记录未找到的用户
|
||||
found_user_ids = set(user_to_app.keys())
|
||||
missing_user_ids = set(end_user_ids) - found_user_ids
|
||||
@@ -1285,13 +1295,13 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 批量查询 memory_config_name
|
||||
config_id_to_name = {}
|
||||
if memory_config_ids:
|
||||
memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all()
|
||||
config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs}
|
||||
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
|
||||
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
|
||||
|
||||
# 4. 构建最终结果
|
||||
for end_user_id, app_id in user_to_app.items():
|
||||
release = app_to_release.get(app_id)
|
||||
|
||||
|
||||
if not release:
|
||||
logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})")
|
||||
result[end_user_id] = {"memory_config_id": None, "memory_config_name": None}
|
||||
@@ -1303,7 +1313,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 获取配置名称
|
||||
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
|
||||
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
|
||||
|
||||
result[end_user_id] = {
|
||||
"memory_config_id": memory_config_id,
|
||||
|
||||
@@ -25,7 +25,7 @@ class MemoryAPIService:
|
||||
|
||||
This service provides a thin layer that:
|
||||
1. Validates end_user exists and belongs to the authorized workspace
|
||||
2. Maps end_user_id to group_id for memory operations
|
||||
2. Maps end_user_id to end_user_id for memory operations
|
||||
3. Delegates to MemoryAgentService for actual memory read/write operations
|
||||
"""
|
||||
|
||||
@@ -68,7 +68,7 @@ class MemoryAPIService:
|
||||
)
|
||||
|
||||
end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first()
|
||||
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"End user not found: {end_user_id}")
|
||||
raise ResourceNotFoundException(
|
||||
@@ -118,7 +118,7 @@ class MemoryAPIService:
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace ID for resource validation
|
||||
end_user_id: End user identifier (used as group_id)
|
||||
end_user_id: End user identifier (used as end_user_id)
|
||||
message: Message content to store
|
||||
config_id: Optional memory configuration ID
|
||||
storage_type: Storage backend (neo4j or rag)
|
||||
@@ -136,14 +136,13 @@ class MemoryAPIService:
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
# Use end_user_id as group_id for memory operations
|
||||
group_id = end_user_id
|
||||
# Use end_user_id as end_user_id for memory operations
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
result = await MemoryAgentService().write_memory(
|
||||
group_id=group_id,
|
||||
message=message,
|
||||
end_user_id=end_user_id,
|
||||
messages=message,
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
@@ -189,7 +188,7 @@ class MemoryAPIService:
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace ID for resource validation
|
||||
end_user_id: End user identifier (used as group_id)
|
||||
end_user_id: End user identifier (used as end_user_id)
|
||||
message: Query message
|
||||
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||
config_id: Optional memory configuration ID
|
||||
@@ -208,13 +207,13 @@ class MemoryAPIService:
|
||||
# Validate end_user exists and belongs to workspace
|
||||
self.validate_end_user(end_user_id, workspace_id)
|
||||
|
||||
# Use end_user_id as group_id for memory operations
|
||||
group_id = end_user_id
|
||||
# Use end_user_id as end_user_id for memory operations
|
||||
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
result = await MemoryAgentService().read_memory(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
message=message,
|
||||
history=[],
|
||||
search_switch=search_switch,
|
||||
|
||||
@@ -326,7 +326,7 @@ class MemoryBaseService:
|
||||
|
||||
Args:
|
||||
summary_id: Summary节点的ID
|
||||
end_user_id: 终端用户ID (group_id)
|
||||
end_user_id: 终端用户ID (end_user_id)
|
||||
|
||||
Returns:
|
||||
最大emotion_intensity对应的emotion_type,如果没有则返回None
|
||||
@@ -334,7 +334,7 @@ class MemoryBaseService:
|
||||
try:
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||
WHERE stmt.emotion_type IS NOT NULL
|
||||
AND stmt.emotion_intensity IS NOT NULL
|
||||
@@ -347,7 +347,7 @@ class MemoryBaseService:
|
||||
result = await self.neo4j_connector.execute_query(
|
||||
query,
|
||||
summary_id=summary_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if result and len(result) > 0:
|
||||
@@ -381,10 +381,10 @@ class MemoryBaseService:
|
||||
if end_user_id:
|
||||
query = """
|
||||
MATCH (n:MemorySummary)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await self.neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||
else:
|
||||
query = """
|
||||
MATCH (n:MemorySummary)
|
||||
@@ -423,12 +423,12 @@ class MemoryBaseService:
|
||||
if end_user_id:
|
||||
semantic_query = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.group_id = $group_id AND e.is_explicit_memory = true
|
||||
WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true
|
||||
RETURN count(e) as count
|
||||
"""
|
||||
semantic_result = await self.neo4j_connector.execute_query(
|
||||
semantic_query,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
else:
|
||||
semantic_query = """
|
||||
@@ -519,7 +519,7 @@ class MemoryBaseService:
|
||||
"""
|
||||
|
||||
if end_user_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
||||
@@ -528,7 +528,7 @@ class MemoryBaseService:
|
||||
# 设置查询参数
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if end_user_id:
|
||||
params['group_id'] = end_user_id
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
# 执行查询
|
||||
result = await self.neo4j_connector.execute_query(query, **params)
|
||||
|
||||
@@ -14,7 +14,7 @@ from app.core.validators.memory_config_validators import (
|
||||
validate_embedding_model,
|
||||
validate_model_exists_and_active,
|
||||
)
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.schemas.memory_config_schema import (
|
||||
ConfigurationError,
|
||||
InvalidConfigError,
|
||||
@@ -23,20 +23,24 @@ from app.schemas.memory_config_schema import (
|
||||
ModelNotFoundError,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from uuid import UUID
|
||||
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
import uuid
|
||||
|
||||
def _validate_config_id(config_id):
|
||||
"""Validate configuration ID format."""
|
||||
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||
if isinstance(config_id, uuid.UUID):
|
||||
return config_id
|
||||
|
||||
if config_id is None:
|
||||
raise InvalidConfigError(
|
||||
"Configuration ID cannot be None",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
if isinstance(config_id, int):
|
||||
if config_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
@@ -45,10 +49,19 @@ def _validate_config_id(config_id):
|
||||
invalid_value=config_id,
|
||||
)
|
||||
return config_id
|
||||
|
||||
|
||||
if isinstance(config_id, str):
|
||||
config_id_stripped = config_id.strip()
|
||||
|
||||
# Try parsing as UUID first
|
||||
try:
|
||||
parsed_id = int(config_id.strip())
|
||||
return uuid.UUID(config_id_stripped)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fall back to integer parsing
|
||||
try:
|
||||
parsed_id = config_id_stripped
|
||||
if parsed_id <= 0:
|
||||
raise InvalidConfigError(
|
||||
f"Configuration ID must be positive: {parsed_id}",
|
||||
@@ -58,13 +71,13 @@ def _validate_config_id(config_id):
|
||||
return parsed_id
|
||||
except ValueError:
|
||||
raise InvalidConfigError(
|
||||
f"Invalid configuration ID format: '{config_id}'",
|
||||
f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
|
||||
|
||||
raise InvalidConfigError(
|
||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
||||
f"Invalid type for configuration ID: expected UUID, int or str, got {type(config_id).__name__}",
|
||||
field_name="config_id",
|
||||
invalid_value=config_id,
|
||||
)
|
||||
@@ -73,61 +86,61 @@ def _validate_config_id(config_id):
|
||||
class MemoryConfigService:
|
||||
"""
|
||||
Centralized service for memory configuration loading and validation.
|
||||
|
||||
|
||||
This class provides a single implementation of configuration loading logic
|
||||
that can be shared across multiple services, eliminating code duplication.
|
||||
|
||||
|
||||
Usage:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""Initialize the service with a database session.
|
||||
|
||||
|
||||
Args:
|
||||
db: SQLAlchemy database session
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
|
||||
def load_memory_config(
|
||||
self,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
service_name: str = "MemoryConfigService",
|
||||
) -> MemoryConfig:
|
||||
"""
|
||||
Load memory configuration from database by config_id.
|
||||
|
||||
|
||||
Args:
|
||||
config_id: Configuration ID from database
|
||||
config_id: Configuration ID (UUID) from database
|
||||
service_name: Name of the calling service (for logging purposes)
|
||||
|
||||
|
||||
Returns:
|
||||
MemoryConfig: Immutable configuration object
|
||||
|
||||
|
||||
Raises:
|
||||
ConfigurationError: If validation fails
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
config_logger.info(
|
||||
"Starting memory configuration loading",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"service": service_name,
|
||||
"config_id": config_id,
|
||||
"config_id": str(config_id),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||
|
||||
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
|
||||
# Step 1: Get config and workspace
|
||||
db_query_start = time.time()
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
if not result:
|
||||
@@ -136,18 +149,18 @@ class MemoryConfigService:
|
||||
"Configuration not found in database",
|
||||
extra={
|
||||
"operation": "load_memory_config",
|
||||
"config_id": validated_config_id,
|
||||
"config_id": str(config_id),
|
||||
"load_result": "not_found",
|
||||
"elapsed_ms": elapsed_ms,
|
||||
"service": service_name,
|
||||
},
|
||||
)
|
||||
raise ConfigurationError(
|
||||
f"Configuration {validated_config_id} not found in database"
|
||||
f"Configuration {config_id} not found in database"
|
||||
)
|
||||
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
|
||||
# Step 2: Validate embedding model (returns both UUID and name)
|
||||
embed_start = time.time()
|
||||
embedding_uuid, embedding_name = validate_embedding_model(
|
||||
@@ -159,7 +172,7 @@ class MemoryConfigService:
|
||||
)
|
||||
embed_time = time.time() - embed_start
|
||||
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
|
||||
|
||||
|
||||
# Step 3: Resolve LLM model
|
||||
llm_start = time.time()
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
@@ -173,7 +186,7 @@ class MemoryConfigService:
|
||||
)
|
||||
llm_time = time.time() - llm_start
|
||||
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
|
||||
|
||||
|
||||
# Step 4: Resolve optional rerank model
|
||||
rerank_start = time.time()
|
||||
rerank_uuid = None
|
||||
@@ -191,10 +204,10 @@ class MemoryConfigService:
|
||||
rerank_time = time.time() - rerank_start
|
||||
if memory_config.rerank_id:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
|
||||
# Note: embedding_name is now returned from validate_embedding_model above
|
||||
# No need for redundant query!
|
||||
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
config_id=memory_config.config_id,
|
||||
@@ -235,9 +248,9 @@ class MemoryConfigService:
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||
)
|
||||
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
config_logger.info(
|
||||
"Memory configuration loaded successfully",
|
||||
extra={
|
||||
@@ -250,13 +263,13 @@ class MemoryConfigService:
|
||||
"elapsed_ms": elapsed_ms,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Memory configuration loaded successfully: {config.config_name}")
|
||||
return config
|
||||
|
||||
|
||||
except Exception as e:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
|
||||
|
||||
config_logger.error(
|
||||
"Failed to load memory configuration",
|
||||
extra={
|
||||
@@ -270,7 +283,7 @@ class MemoryConfigService:
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
logger.error(f"Failed to load memory configuration {config_id}: {e}")
|
||||
if isinstance(e, (ConfigurationError, ValueError)):
|
||||
raise
|
||||
|
||||
@@ -717,8 +717,8 @@ class MemoryInteraction:
|
||||
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
||||
if ori_data!=[]:
|
||||
# name = ori_data[0]['name']
|
||||
group_id = [i['group_id'] for i in ori_data][0]
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
||||
end_user_id = [i['end_user_id'] for i in ori_data][0]
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id)
|
||||
if not Space_User:
|
||||
return []
|
||||
user_id=Space_User[0]['id']
|
||||
|
||||
@@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
|
||||
Args:
|
||||
summary_id: Summary节点的ID
|
||||
end_user_id: 终端用户ID (group_id)
|
||||
end_user_id: 终端用户ID (end_user_id)
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组,如果不存在则返回默认值
|
||||
@@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 查询Summary节点的name(作为title)和memory_type(作为type)
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||
RETURN s.name AS title, s.memory_type AS type
|
||||
"""
|
||||
|
||||
result = await self.neo4j_connector.execute_query(
|
||||
query,
|
||||
summary_id=summary_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if not result or len(result) == 0:
|
||||
@@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
|
||||
Args:
|
||||
summary_id: Summary节点的ID
|
||||
end_user_id: 终端用户ID (group_id)
|
||||
end_user_id: 终端用户ID (end_user_id)
|
||||
|
||||
Returns:
|
||||
前3个实体的name属性列表
|
||||
@@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 按activation_value降序排序,返回前3个
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||
MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity)
|
||||
WHERE entity.activation_value IS NOT NULL
|
||||
@@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
result = await self.neo4j_connector.execute_query(
|
||||
query,
|
||||
summary_id=summary_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 提取实体名称
|
||||
@@ -123,7 +123,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
|
||||
Args:
|
||||
summary_id: Summary节点的ID
|
||||
end_user_id: 终端用户ID (group_id)
|
||||
end_user_id: 终端用户ID (end_user_id)
|
||||
|
||||
Returns:
|
||||
所有Statement节点的statement属性内容列表
|
||||
@@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 查询Summary节点指向的所有Statement节点
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||
MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement)
|
||||
WHERE stmt.statement IS NOT NULL AND stmt.statement <> ''
|
||||
RETURN stmt.statement AS statement
|
||||
@@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
result = await self.neo4j_connector.execute_query(
|
||||
query,
|
||||
summary_id=summary_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 提取statement内容
|
||||
@@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 1. 先查询所有情景记忆的总数(不受筛选条件限制)
|
||||
total_all_query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE s.group_id = $group_id
|
||||
WHERE s.end_user_id = $end_user_id
|
||||
RETURN count(s) AS total_all
|
||||
"""
|
||||
total_all_result = await self.neo4j_connector.execute_query(
|
||||
total_all_query,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
total_all = total_all_result[0]["total_all"] if total_all_result else 0
|
||||
|
||||
@@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 3. 构建Cypher查询
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE s.group_id = $group_id
|
||||
WHERE s.end_user_id = $end_user_id
|
||||
"""
|
||||
|
||||
# 添加时间范围过滤
|
||||
@@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
ORDER BY s.created_at DESC
|
||||
"""
|
||||
|
||||
params = {"group_id": end_user_id}
|
||||
params = {"end_user_id": end_user_id}
|
||||
if time_filter:
|
||||
params["time_filter"] = time_filter
|
||||
if title_keyword:
|
||||
@@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService):
|
||||
# 1. 查询指定的MemorySummary节点
|
||||
query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $summary_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id
|
||||
RETURN elementId(s) AS id, s.created_at AS created_at
|
||||
"""
|
||||
|
||||
result = await self.neo4j_connector.execute_query(
|
||||
query,
|
||||
summary_id=summary_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 2. 如果节点不存在,返回错误
|
||||
|
||||
@@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
# ========== 1. 查询情景记忆(MemorySummary节点) ==========
|
||||
episodic_query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE s.group_id = $group_id
|
||||
WHERE s.end_user_id = $end_user_id
|
||||
RETURN elementId(s) AS id,
|
||||
s.name AS title,
|
||||
s.content AS content,
|
||||
@@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
|
||||
episodic_result = await self.neo4j_connector.execute_query(
|
||||
episodic_query,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 处理情景记忆数据
|
||||
@@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
# ========== 2. 查询语义记忆(ExtractedEntity节点) ==========
|
||||
semantic_query = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.group_id = $group_id
|
||||
WHERE e.end_user_id = $end_user_id
|
||||
AND e.is_explicit_memory = true
|
||||
RETURN elementId(e) AS id,
|
||||
e.name AS name,
|
||||
@@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
|
||||
semantic_result = await self.neo4j_connector.execute_query(
|
||||
semantic_query,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 处理语义记忆数据
|
||||
@@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
# ========== 1. 先尝试查询情景记忆 ==========
|
||||
episodic_query = """
|
||||
MATCH (s:MemorySummary)
|
||||
WHERE elementId(s) = $memory_id AND s.group_id = $group_id
|
||||
WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id
|
||||
RETURN s.name AS title,
|
||||
s.content AS content,
|
||||
s.created_at AS created_at
|
||||
@@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
episodic_result = await self.neo4j_connector.execute_query(
|
||||
episodic_query,
|
||||
memory_id=memory_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if episodic_result and len(episodic_result) > 0:
|
||||
@@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
semantic_query = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE elementId(e) = $memory_id
|
||||
AND e.group_id = $group_id
|
||||
AND e.end_user_id = $end_user_id
|
||||
AND e.is_explicit_memory = true
|
||||
RETURN e.name AS name,
|
||||
e.description AS core_definition,
|
||||
@@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService):
|
||||
semantic_result = await self.neo4j_connector.execute_query(
|
||||
semantic_query,
|
||||
memory_id=memory_id,
|
||||
group_id=end_user_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
if semantic_result and len(semantic_result) > 0:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -23,7 +24,7 @@ from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
load_actr_config_from_db,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.forgetting_cycle_history_repository import ForgettingCycleHistoryRepository
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ class MemoryForgetService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.config_repository = DataConfigRepository()
|
||||
self.config_repository = MemoryConfigRepository()
|
||||
self.history_repository = ForgettingCycleHistoryRepository()
|
||||
|
||||
def _get_neo4j_connector(self) -> Neo4jConnector:
|
||||
@@ -87,7 +88,7 @@ class MemoryForgetService:
|
||||
async def _get_forgetting_components(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
|
||||
"""
|
||||
获取遗忘引擎组件(计算器、策略、调度器)
|
||||
@@ -132,7 +133,7 @@ class MemoryForgetService:
|
||||
async def _get_knowledge_stats(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
forgetting_threshold: float = 0.3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -140,7 +141,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
forgetting_threshold: 遗忘阈值
|
||||
|
||||
Returns:
|
||||
@@ -152,8 +153,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
WITH n,
|
||||
@@ -172,8 +173,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
@@ -200,7 +201,7 @@ class MemoryForgetService:
|
||||
async def _get_pending_forgetting_nodes(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int,
|
||||
limit: int = 20
|
||||
@@ -212,7 +213,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
group_id: 组ID
|
||||
end_user_id: 组ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
limit: 返回节点数量限制
|
||||
@@ -229,7 +230,7 @@ class MemoryForgetService:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.group_id = $group_id
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
@@ -250,7 +251,7 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {
|
||||
'group_id': group_id,
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str,
|
||||
'limit': limit
|
||||
@@ -291,10 +292,10 @@ class MemoryForgetService:
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
max_merge_batch_size: Optional[int] = None,
|
||||
min_days_since_access: Optional[int] = None,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
@@ -303,10 +304,10 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(即终端用户ID,必填)
|
||||
end_user_id: 组ID(即终端用户ID,必填)
|
||||
max_merge_batch_size: 最大融合批次大小(可选)
|
||||
min_days_since_access: 最小未访问天数(可选)
|
||||
config_id: 配置ID(必填,由控制器层通过 group_id 获取)
|
||||
config_id: 配置ID(必填,由控制器层通过 end_user_id 获取)
|
||||
|
||||
Returns:
|
||||
dict: 遗忘报告
|
||||
@@ -319,7 +320,7 @@ class MemoryForgetService:
|
||||
|
||||
# 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取)
|
||||
report = await forgetting_scheduler.run_forgetting_cycle(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
max_merge_batch_size=max_merge_batch_size,
|
||||
min_days_since_access=min_days_since_access,
|
||||
config_id=config_id,
|
||||
@@ -338,7 +339,7 @@ class MemoryForgetService:
|
||||
stats_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
AND n.group_id = $group_id
|
||||
AND n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
count(n) as total_nodes,
|
||||
avg(n.activation_value) as average_activation,
|
||||
@@ -347,7 +348,7 @@ class MemoryForgetService:
|
||||
|
||||
stats_results = await connector.execute_query(
|
||||
stats_query,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
threshold=config['forgetting_threshold']
|
||||
)
|
||||
|
||||
@@ -364,7 +365,7 @@ class MemoryForgetService:
|
||||
# 保存历史记录到数据库
|
||||
self.history_repository.create(
|
||||
db=db,
|
||||
end_user_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
execution_time=execution_time,
|
||||
merged_count=report['merged_count'],
|
||||
failed_count=report['failed_count'],
|
||||
@@ -376,7 +377,7 @@ class MemoryForgetService:
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"已保存遗忘周期历史记录: end_user_id={group_id}, "
|
||||
f"已保存遗忘周期历史记录: end_user_id={end_user_id}, "
|
||||
f"merged_count={report['merged_count']}"
|
||||
)
|
||||
|
||||
@@ -389,7 +390,7 @@ class MemoryForgetService:
|
||||
def read_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int
|
||||
config_id: UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
@@ -416,7 +417,7 @@ class MemoryForgetService:
|
||||
def update_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
update_fields: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -465,8 +466,8 @@ class MemoryForgetService:
|
||||
async def get_forgetting_stats(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
@@ -475,7 +476,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 组ID(可选)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
|
||||
Returns:
|
||||
@@ -493,8 +494,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
activation_query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
activation_query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
activation_query += """
|
||||
RETURN
|
||||
@@ -506,8 +507,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
activation_results = await connector.execute_query(activation_query, **params)
|
||||
|
||||
@@ -539,8 +540,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
distribution_query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
distribution_query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
distribution_query += """
|
||||
WITH n,
|
||||
@@ -558,8 +559,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
dist_params = {}
|
||||
if group_id:
|
||||
dist_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
dist_params['end_user_id'] = end_user_id
|
||||
|
||||
distribution_results = await connector.execute_query(distribution_query, **dist_params)
|
||||
|
||||
@@ -582,11 +583,11 @@ class MemoryForgetService:
|
||||
# 获取最近7个日期的历史趋势数据(每天取最后一次执行)
|
||||
recent_trends = []
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 查询所有历史记录
|
||||
history_records = self.history_repository.get_recent_by_end_user(
|
||||
db=db,
|
||||
end_user_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 按日期分组(一天可能有多次执行,取最后一次)
|
||||
@@ -632,7 +633,7 @@ class MemoryForgetService:
|
||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||
pending_nodes = []
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 验证 min_days_since_access 配置值
|
||||
min_days = config.get('min_days_since_access')
|
||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||
@@ -643,7 +644,7 @@ class MemoryForgetService:
|
||||
|
||||
pending_nodes = await self._get_pending_forgetting_nodes(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
limit=20
|
||||
@@ -677,7 +678,7 @@ class MemoryForgetService:
|
||||
db: Session,
|
||||
importance_score: float,
|
||||
days: int,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
@@ -450,12 +450,12 @@ async def create_document_chunk(
|
||||
|
||||
return success(data=chunk, msg="文档块创建成功")
|
||||
|
||||
async def write_rag(group_id, message, user_rag_memory_id):
|
||||
async def write_rag(end_user_id, message, user_rag_memory_id):
|
||||
"""
|
||||
将消息写入 RAG 知识库
|
||||
|
||||
Args:
|
||||
group_id: 组ID,用作文件标题
|
||||
end_user_id: 组ID,用作文件标题
|
||||
message: 消息内容
|
||||
user_rag_memory_id: 知识库ID(必须是有效的UUID)
|
||||
|
||||
@@ -487,10 +487,10 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
||||
db = next(db_gen)
|
||||
|
||||
try:
|
||||
create_data = CustomTextFileCreate(title=group_id, content=message)
|
||||
create_data = CustomTextFileCreate(title=end_user_id, content=message)
|
||||
current_user = SimpleUser(user_rag_memory_id)
|
||||
# 检查文档是否已存在
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt")
|
||||
document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt")
|
||||
print('======',document)
|
||||
api_logger.info(f"查找文档结果: document_id={document}")
|
||||
if document is not None:
|
||||
@@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
||||
return result
|
||||
else:
|
||||
# 文档不存在,创建新文档
|
||||
api_logger.info(f"文档不存在,创建新文档: group_id={group_id}")
|
||||
api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}")
|
||||
result = await memory_konwledges_up(
|
||||
kb_id=user_rag_memory_id,
|
||||
parent_id=user_rag_memory_id,
|
||||
@@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id):
|
||||
new_document_id = find_document_id_by_kb_and_filename(
|
||||
db=db,
|
||||
kb_id=user_rag_memory_id,
|
||||
file_name=f"{group_id}.txt"
|
||||
file_name=f"{end_user_id}.txt"
|
||||
)
|
||||
|
||||
if new_document_id:
|
||||
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
|
||||
else:
|
||||
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}")
|
||||
api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}")
|
||||
return result
|
||||
finally:
|
||||
# 确保数据库会话被关闭
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageType
|
||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||
from app.schemas.memory_perceptual_schema import (
|
||||
PerceptualQuerySchema,
|
||||
@@ -137,8 +137,19 @@ class MemoryPerceptualService:
|
||||
memory_items = []
|
||||
for memory in memories:
|
||||
meta_data = memory.meta_data or {}
|
||||
content = meta_data.get("content")
|
||||
content = Content(**content)
|
||||
content = meta_data.get("content", {})
|
||||
|
||||
# 安全地提取 content 字段,提供默认值
|
||||
if content:
|
||||
content_obj = Content(**content)
|
||||
topic = content_obj.topic
|
||||
domain = content_obj.domain
|
||||
keywords = content_obj.keywords
|
||||
else:
|
||||
topic = "Unknown"
|
||||
domain = "Unknown"
|
||||
keywords = []
|
||||
|
||||
memory_item = PerceptualMemoryItem(
|
||||
id=memory.id,
|
||||
perceptual_type=PerceptualType(memory.perceptual_type),
|
||||
@@ -146,11 +157,12 @@ class MemoryPerceptualService:
|
||||
file_name=memory.file_name,
|
||||
file_ext=memory.file_ext,
|
||||
summary=memory.summary,
|
||||
topic=content.topic,
|
||||
domain=content.domain,
|
||||
keywords=content.keywords,
|
||||
meta_data=meta_data,
|
||||
topic=topic,
|
||||
domain=domain,
|
||||
keywords=keywords,
|
||||
created_time=int(memory.created_time.timestamp()*1000),
|
||||
storage_type=FileStorageType(memory.storage_service),
|
||||
storage_service=FileStorageService(memory.storage_service),
|
||||
)
|
||||
memory_items.append(memory_item)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine import ReflectionConfig, ReflectionEngine
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import ReflectionRange, ReflectionBaseline
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.models.app_model import App
|
||||
from app.models.app_release_model import AppRelease
|
||||
@@ -73,7 +73,7 @@ class WorkspaceAppService:
|
||||
"created_at": app.created_at.isoformat() if app.created_at else None,
|
||||
"updated_at": app.updated_at.isoformat() if app.updated_at else None,
|
||||
"releases": [],
|
||||
"data_configs": [],
|
||||
"memory_configs": [],
|
||||
"end_users": []
|
||||
}
|
||||
|
||||
@@ -101,11 +101,11 @@ class WorkspaceAppService:
|
||||
|
||||
if memory_content:
|
||||
processed_configs.add(memory_content)
|
||||
data_config_info = self._get_data_config(memory_content)
|
||||
memory_config_info = self._get_memory_config(memory_content)
|
||||
|
||||
if data_config_info:
|
||||
if not any(dc["config_id"] == data_config_info["config_id"] for dc in app_info["data_configs"]):
|
||||
app_info["data_configs"].append(data_config_info)
|
||||
if memory_config_info:
|
||||
if not any(dc["config_id"] == memory_config_info["config_id"] for dc in app_info["memory_configs"]):
|
||||
app_info["memory_configs"].append(memory_config_info)
|
||||
|
||||
app_info["releases"].append(release_info)
|
||||
|
||||
@@ -120,30 +120,30 @@ class WorkspaceAppService:
|
||||
|
||||
return None
|
||||
|
||||
def _get_data_config(self, memory_content: str) -> Dict[str, Any]:
|
||||
"""Retrieve data_comfig information based on memory_comtent"""
|
||||
def _get_memory_config(self, memory_content: str) -> Dict[str, Any]:
|
||||
"""Retrieve memory_config information based on memory_content"""
|
||||
try:
|
||||
data_config_result = DataConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
||||
memory_config_result = MemoryConfigRepository.query_reflection_config_by_id(self.db, int(memory_content))
|
||||
|
||||
# data_config_query, data_config_params = DataConfigRepository.build_select_reflection(memory_content)
|
||||
# data_config_result = self.db.execute(text(data_config_query), data_config_params).fetchone()
|
||||
# if data_config_result is None:
|
||||
# memory_config_query, memory_config_params = MemoryConfigRepository.build_select_reflection(memory_content)
|
||||
# memory_config_result = self.db.execute(text(memory_config_query), memory_config_params).fetchone()
|
||||
# if memory_config_result is None:
|
||||
# return None
|
||||
|
||||
if data_config_result:
|
||||
if memory_config_result:
|
||||
return {
|
||||
"config_id": data_config_result.config_id,
|
||||
"enable_self_reflexion": data_config_result.enable_self_reflexion,
|
||||
"iteration_period": data_config_result.iteration_period,
|
||||
"reflexion_range": data_config_result.reflexion_range,
|
||||
"baseline": data_config_result.baseline,
|
||||
"reflection_model_id": data_config_result.reflection_model_id,
|
||||
"memory_verify": data_config_result.memory_verify,
|
||||
"quality_assessment": data_config_result.quality_assessment,
|
||||
"user_id": data_config_result.user_id
|
||||
"config_id": memory_config_result.config_id,
|
||||
"enable_self_reflexion": memory_config_result.enable_self_reflexion,
|
||||
"iteration_period": memory_config_result.iteration_period,
|
||||
"reflexion_range": memory_config_result.reflexion_range,
|
||||
"baseline": memory_config_result.baseline,
|
||||
"reflection_model_id": memory_config_result.reflection_model_id,
|
||||
"memory_verify": memory_config_result.memory_verify,
|
||||
"quality_assessment": memory_config_result.quality_assessment,
|
||||
"user_id": memory_config_result.user_id
|
||||
}
|
||||
except Exception as e:
|
||||
api_logger.warning(f"查询data_config失败,memory_content: {memory_content}, 错误: {str(e)}")
|
||||
api_logger.warning(f"查询memory_config失败,memory_content: {memory_content}, 错误: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
@@ -226,7 +226,7 @@ class MemoryReflectionService:
|
||||
}
|
||||
|
||||
config_data_id = config_data['config_id']
|
||||
reflection_config = WorkspaceAppService(self.db)._get_data_config(config_data_id)
|
||||
reflection_config = WorkspaceAppService(self.db)._get_memory_config(config_data_id)
|
||||
if reflection_config is not None and reflection_config['enable_self_reflexion']:
|
||||
reflection_config = self._create_reflection_config_from_data(reflection_config)
|
||||
# 3. 执行反思引擎
|
||||
@@ -280,7 +280,7 @@ class MemoryReflectionService:
|
||||
|
||||
|
||||
config_data_id=config_data['config_id']
|
||||
reflection_config=WorkspaceAppService(self.db)._get_data_config(config_data_id)
|
||||
reflection_config=WorkspaceAppService(self.db)._get_memory_config(config_data_id)
|
||||
if reflection_config is not None and reflection_config['enable_self_reflexion']:
|
||||
reflection_config= self._create_reflection_config_from_data(reflection_config)
|
||||
iteration_period = int(reflection_config.iteration_period)
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.analytics.hot_memory_tags import (
|
||||
)
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.schemas.memory_storage_schema import (
|
||||
@@ -129,7 +129,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
if not params.rerank_id:
|
||||
params.rerank_id = configs.get('rerank')
|
||||
|
||||
config = DataConfigRepository.create(self.db, params)
|
||||
config = MemoryConfigRepository.create(self.db, params)
|
||||
self.db.commit()
|
||||
return {"affected": 1, "config_id": config.config_id}
|
||||
|
||||
@@ -146,20 +146,20 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# --- Delete ---
|
||||
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID)
|
||||
success = DataConfigRepository.delete(self.db, key.config_id)
|
||||
success = MemoryConfigRepository.delete(self.db, key.config_id)
|
||||
if not success:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
|
||||
# --- Update ---
|
||||
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
|
||||
config = DataConfigRepository.update(self.db, update)
|
||||
config = MemoryConfigRepository.update(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
|
||||
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
|
||||
config = DataConfigRepository.update_extracted(self.db, update)
|
||||
config = MemoryConfigRepository.update_extracted(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
@@ -170,14 +170,14 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||
result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
|
||||
result = MemoryConfigRepository.get_extracted_config(self.db, key.config_id)
|
||||
if not result:
|
||||
raise ValueError("未找到配置")
|
||||
return result
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
configs = DataConfigRepository.get_all(self.db, workspace_id)
|
||||
configs = MemoryConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
# 将 ORM 对象转换为字典列表
|
||||
data_list = []
|
||||
@@ -187,7 +187,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"config_name": config.config_name,
|
||||
"config_desc": config.config_desc,
|
||||
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
||||
"group_id": config.group_id,
|
||||
"end_user_id": config.end_user_id,
|
||||
"user_id": config.user_id,
|
||||
"apply_id": config.apply_id,
|
||||
"llm_id": config.llm_id,
|
||||
@@ -395,8 +395,8 @@ _neo4j_connector = Neo4jConnector()
|
||||
|
||||
async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_DIALOGUE,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_DIALOGUE,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "dialogue", "num": result[0]["num"]}
|
||||
return data
|
||||
@@ -404,8 +404,8 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_CHUNK,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_CHUNK,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "chunk", "num": result[0]["num"]}
|
||||
return data
|
||||
@@ -413,8 +413,8 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_STATEMENT,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_STATEMENT,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "statement", "num": result[0]["num"]}
|
||||
return data
|
||||
@@ -422,8 +422,8 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ENTITY,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_ENTITY,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
data = {"search_for": "entity", "num": result[0]["num"]}
|
||||
return data
|
||||
@@ -431,8 +431,8 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
|
||||
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
@@ -466,8 +466,8 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
|
||||
聚合 dialogue/chunk/statement/entity 四类计数,返回统一的分布结构,便于前端一次性消费。
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_ALL,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
|
||||
# 检查结果是否为空或长度不足
|
||||
@@ -497,21 +497,19 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A
|
||||
|
||||
async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_DETIALS,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_DETIALS,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
result = await _neo4j_connector.execute_query(
|
||||
DataConfigRepository.SEARCH_FOR_EDGES,
|
||||
group_id=end_user_id,
|
||||
MemoryConfigRepository.SEARCH_FOR_EDGES,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
@@ -574,7 +572,7 @@ async def analytics_hot_memory_tags(
|
||||
# 步骤4: 只调用一次LLM进行筛选
|
||||
tag_names = [tag for tag, _ in sorted_tags]
|
||||
|
||||
# 使用第一个用户的group_id来获取LLM配置
|
||||
# 使用第一个用户的end_user_id来获取LLM配置
|
||||
# 因为同一工作空间下的用户应该使用相同的配置
|
||||
first_end_user_id = str(end_users[0].id)
|
||||
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
|
||||
|
||||
@@ -91,7 +91,7 @@ async def run_pilot_extraction(
|
||||
dialog = DialogData(
|
||||
context=context,
|
||||
ref_id="pilot_dialog_1",
|
||||
group_id=str(memory_config.workspace_id),
|
||||
end_user_id=str(memory_config.workspace_id),
|
||||
user_id=str(memory_config.tenant_id),
|
||||
apply_id=str(memory_config.config_id),
|
||||
metadata={"source": "pilot_run", "input_type": "frontend_text"},
|
||||
|
||||
@@ -155,10 +155,10 @@ class MemoryInsightHelper:
|
||||
"""
|
||||
query = """
|
||||
MATCH (d:Dialogue)
|
||||
WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> ''
|
||||
RETURN d.created_at AS creation_time
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
|
||||
|
||||
if not records:
|
||||
return []
|
||||
@@ -211,17 +211,17 @@ class MemoryInsightHelper:
|
||||
async def get_social_connections(self) -> dict | None:
|
||||
"""Find the user with whom the most memories are shared."""
|
||||
query = """
|
||||
MATCH (c1:Chunk {group_id: $group_id})
|
||||
MATCH (c1:Chunk {end_user_id: $end_user_id})
|
||||
OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement)
|
||||
OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk)
|
||||
WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||
WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||
WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL
|
||||
WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements
|
||||
WHERE common_statements > 0
|
||||
RETURN other_user_id, common_statements
|
||||
ORDER BY common_statements DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
records = await self.neo4j_connector.execute_query(query, group_id=self.user_id)
|
||||
records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id)
|
||||
if not records or not records[0].get("other_user_id"):
|
||||
return None
|
||||
|
||||
@@ -230,7 +230,7 @@ class MemoryInsightHelper:
|
||||
|
||||
time_range_query = """
|
||||
MATCH (c:Chunk)
|
||||
WHERE c.group_id IN [$user_id, $other_user_id]
|
||||
WHERE c.end_user_id IN [$user_id, $other_user_id]
|
||||
RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time
|
||||
"""
|
||||
time_records = await self.neo4j_connector.execute_query(
|
||||
@@ -294,11 +294,11 @@ class UserSummaryHelper:
|
||||
"""Fetch recent statements authored by the user/group for context."""
|
||||
query = (
|
||||
"MATCH (s:Statement) "
|
||||
"WHERE s.group_id = $group_id AND s.statement IS NOT NULL "
|
||||
"WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL "
|
||||
"RETURN s.statement AS statement, s.created_at AS created_at "
|
||||
"ORDER BY created_at DESC LIMIT $limit"
|
||||
)
|
||||
rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit)
|
||||
rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit)
|
||||
records = []
|
||||
for r in rows:
|
||||
try:
|
||||
@@ -1152,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str,
|
||||
import re
|
||||
|
||||
# 创建 UserSummaryHelper 实例
|
||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123"))
|
||||
user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123"))
|
||||
|
||||
try:
|
||||
# 1) 收集上下文数据
|
||||
@@ -1273,10 +1273,10 @@ async def analytics_node_statistics(
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
@@ -1387,10 +1387,10 @@ async def analytics_memory_types(
|
||||
# 查询 Statement 节点数量
|
||||
query = """
|
||||
MATCH (n:Statement)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id)
|
||||
statement_count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
# 取三分之一作为隐性记忆数量
|
||||
implicit_count = round(statement_count / 3)
|
||||
@@ -1504,7 +1504,7 @@ async def analytics_graph_data(
|
||||
包含节点、边和统计信息的字典
|
||||
"""
|
||||
try:
|
||||
# 1. 获取 group_id
|
||||
# 1. 获取 end_user_id
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
@@ -1528,7 +1528,7 @@ async def analytics_graph_data(
|
||||
# 基于中心节点的扩展查询
|
||||
node_query = f"""
|
||||
MATCH path = (center)-[*1..{depth}]-(connected)
|
||||
WHERE center.group_id = $group_id
|
||||
WHERE center.end_user_id = $end_user_id
|
||||
AND elementId(center) = $center_node_id
|
||||
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
|
||||
UNWIND all_nodes as n
|
||||
@@ -1539,7 +1539,7 @@ async def analytics_graph_data(
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"center_node_id": center_node_id,
|
||||
"limit": limit
|
||||
}
|
||||
@@ -1547,7 +1547,7 @@ async def analytics_graph_data(
|
||||
# 按节点类型过滤查询
|
||||
node_query = """
|
||||
MATCH (n)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
AND labels(n)[0] IN $node_types
|
||||
RETURN
|
||||
elementId(n) as id,
|
||||
@@ -1556,7 +1556,7 @@ async def analytics_graph_data(
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"node_types": node_types,
|
||||
"limit": limit
|
||||
}
|
||||
@@ -1564,7 +1564,7 @@ async def analytics_graph_data(
|
||||
# 查询所有节点
|
||||
node_query = """
|
||||
MATCH (n)
|
||||
WHERE n.group_id = $group_id
|
||||
WHERE n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
@@ -1572,7 +1572,7 @@ async def analytics_graph_data(
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user