Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
# Conflicts: # api/pyproject.toml
This commit is contained in:
@@ -10,11 +10,6 @@ import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.tool_service import ToolService
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = get_business_logger()
|
||||
class KnowledgeRetrievalInput(BaseModel):
|
||||
@@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
||||
"app.core.memory.agent.read_message",
|
||||
args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
result = task_service.get_task_memory_read_result(task.id)
|
||||
status = result.get("status")
|
||||
logger.info(f"读取任务状态:{status}")
|
||||
# result = task_service.get_task_memory_read_result(task.id)
|
||||
# status = result.get("status")
|
||||
# logger.info(f"读取任务状态:{status}")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
@@ -10,26 +10,32 @@ import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
import redis
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
import redis
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
||||
from app.core.memory.agent.utils.messages_tools import (
|
||||
merge_multiple_search_results,
|
||||
reorder_output_results,
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -259,13 +265,13 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -286,7 +292,7 @@ class MemoryAgentService:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
raise
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
@@ -314,14 +320,28 @@ class MemoryAgentService:
|
||||
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(group_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
from langchain_core.messages import AIMessage
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
|
||||
"memory_config": memory_config}
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
@@ -334,7 +354,9 @@ class MemoryAgentService:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
@@ -385,6 +407,7 @@ class MemoryAgentService:
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}")
|
||||
|
||||
# Resolve config_id if None using end_user's connected config
|
||||
if config_id is None:
|
||||
@@ -408,13 +431,15 @@ class MemoryAgentService:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
logger.info(f"Configuration loaded successfully: {memory_config.config_name}")
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
|
||||
except ConfigurationError as e:
|
||||
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
|
||||
logger.error(error_msg)
|
||||
@@ -438,6 +463,7 @@ class MemoryAgentService:
|
||||
logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}")
|
||||
|
||||
# Step 3: Initialize MCP client and execute read workflow
|
||||
graph_exec_start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
@@ -494,12 +520,68 @@ class MemoryAgentService:
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
graph_exec_time = time.time() - graph_exec_start
|
||||
logger.info(f"[PERF] Graph execution completed in {graph_exec_time:.4f}s")
|
||||
|
||||
_intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
|
||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
result = reorder_output_results(optimized_outputs)
|
||||
|
||||
# 保存短期记忆到数据库
|
||||
# 只有 search_switch 不为 "2"(快速检索)时才保存
|
||||
try:
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
|
||||
retrieved_content = []
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
|
||||
if str(search_switch) != "2":
|
||||
for intermediate in _intermediate_outputs:
|
||||
logger.debug(f"处理中间结果: {intermediate}")
|
||||
intermediate_type = intermediate.get('type', '')
|
||||
|
||||
if intermediate_type == "search_result":
|
||||
query = intermediate.get('query', '')
|
||||
raw_results = intermediate.get('raw_results', {})
|
||||
reranked_results = raw_results.get('reranked_results', [])
|
||||
|
||||
try:
|
||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements = []
|
||||
|
||||
# 去重
|
||||
statements = list(set(statements))
|
||||
|
||||
if query and statements:
|
||||
retrieved_content.append({query: statements})
|
||||
|
||||
# 如果 retrieved_content 为空,设置为空字符串
|
||||
if retrieved_content == []:
|
||||
retrieved_content = ''
|
||||
|
||||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=group_id,
|
||||
messages=message,
|
||||
aimessages=summary,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
|
||||
else:
|
||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
except Exception as save_error:
|
||||
# 保存失败不应该影响主流程,只记录错误
|
||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||
|
||||
# Log successful operation
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -517,7 +599,8 @@ class MemoryAgentService:
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
total_time = time.time() - start_time
|
||||
logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
@@ -530,7 +613,49 @@ class MemoryAgentService:
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
    """
    Validate and return the structured message list carried by ``user_input``.

    Every message must be a dict with a ``role`` of ``'user'`` or
    ``'assistant'`` and a non-blank string ``content``. The list is returned
    unchanged (same object) once every message has passed validation.

    Args:
        user_input: Write_UserInput object; only its ``messages`` attribute is read.

    Returns:
        list[dict]: The validated ``user_input.messages`` list.

    Raises:
        ValueError: If the list is missing/empty, a message is not a dict,
            a required key is absent, the role is not allowed, or the content
            is blank or not a string.
    """
    # The API logger is used here on purpose; it shadows the module-level logger.
    from app.core.logging_config import get_api_logger
    logger = get_api_logger()

    messages = user_input.messages

    # Truthiness check also rejects ``None`` with a ValueError instead of
    # letting ``len(None)`` escape as a TypeError.
    if not messages:
        logger.error("Validation failed: Message list cannot be empty")
        raise ValueError("Message list cannot be empty")

    for idx, msg in enumerate(messages):
        if not isinstance(msg, dict):
            logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
            raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")

        if 'role' not in msg:
            logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
            raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")

        if 'content' not in msg:
            logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
            raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")

        role = msg['role']
        if role not in ('user', 'assistant'):
            logger.error(f"Validation failed: Message {idx} invalid role: {role}")
            raise ValueError(f"Role must be 'user' or 'assistant', got: {role}. Message index: {idx}")

        content = msg['content']

        # Non-empty, non-string content used to crash on ``.strip()`` with an
        # AttributeError; report it as the documented ValueError instead.
        if content and not isinstance(content, str):
            logger.error(f"Validation failed: Message {idx} content is not a string: {type(content)}")
            raise ValueError(f"Message format error: 'content' must be a string. Message index: {idx}, type: {type(content)}")

        if not content or not content.strip():
            logger.error(f"Validation failed: Message {idx} content is empty")
            raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {role}")

    logger.info(f"Validation successful: Structured message list, count: {len(messages)}")
    return messages
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
@@ -558,7 +683,67 @@ class MemoryAgentService:
|
||||
logger.debug(f"Message type: {status}")
|
||||
return status
|
||||
|
||||
# ==================== 新增的三个接口方法 ====================
|
||||
async def generate_summary_from_retrieve(
    self,
    retrieve_info: str,
    history: List[Dict],
    query: str,
    config_id: str,
    db: Session
) -> str:
    """
    Generate the final answer from retrieved information, chat history and the query.

    Loads the memory configuration identified by ``config_id`` and calls
    ``summary_llm`` with the ``Retrieve_Summary_prompt.jinja2`` template.

    Args:
        retrieve_info: Retrieved information passed to the summary call.
        history: Conversation history passed to the summary call.
        query: User query; stored under the ``"data"`` key of the state.
        config_id: Configuration ID used to load the memory configuration.
        db: Database session handed to ``MemoryConfigService``.

    Returns:
        The generated answer, or the fixed fallback sentence when the answer
        is empty or any step inside the ``try`` block raises.
    """
    fallback_answer = "信息不足,无法回答。"

    logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")

    try:
        # Load the configuration first, exactly as before the imports below.
        memory_config = MemoryConfigService(db).load_memory_config(
            config_id=config_id,
            service_name="MemoryAgentService",
        )

        # Imported at call time inside the try block, so an import failure
        # also ends in the fallback answer.
        from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm
        from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse

        answer = await summary_llm(
            state={"data": query, "memory_config": memory_config},
            history=history,
            retrieve_info=retrieve_info,
            template_name='Retrieve_Summary_prompt.jinja2',
            operation_name='retrieve_summary',
            response_model=RetrieveSummaryResponse,
            search_mode="1",
        )

        preview = answer[:100] if answer else 'None'
        logger.info(f"Successfully generated summary: {preview}...")

        if answer:
            return answer
        return fallback_answer

    except Exception as exc:
        # Best effort: any failure is logged and replaced by the fallback.
        logger.error(f"生成摘要失败: {str(exc)}", exc_info=True)
        return fallback_answer
|
||||
|
||||
|
||||
async def get_knowledge_type_stats(
|
||||
self,
|
||||
@@ -692,7 +877,9 @@ class MemoryAgentService:
|
||||
async def get_hot_memory_tags_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
limit: int = 20,
|
||||
model_id: Optional[str] = None,
|
||||
language_type: Optional[str] = "zh"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
@@ -710,7 +897,13 @@ class MemoryAgentService:
|
||||
try:
|
||||
# by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度)
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||
payload = [{"name": t, "frequency": f} for t, f in tags]
|
||||
payload=[]
|
||||
for tag, freq in tags:
|
||||
if language_type!="zh":
|
||||
tag=await Translation_English(model_id, tag)
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
else:
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.error(f"热门记忆标签查询失败: {e}")
|
||||
@@ -1024,7 +1217,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
from app.models.data_config_model import DataConfig
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users")
|
||||
@@ -1082,8 +1275,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
# 批量查询 memory_config_name
|
||||
config_id_to_name = {}
|
||||
if memory_config_ids:
|
||||
memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all()
|
||||
config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs}
|
||||
memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all()
|
||||
config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs}
|
||||
|
||||
# 4. 构建最终结果
|
||||
for end_user_id, app_id in user_to_app.items():
|
||||
@@ -1100,7 +1293,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
|
||||
|
||||
# 获取配置名称
|
||||
memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None
|
||||
memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None
|
||||
|
||||
result[end_user_id] = {
|
||||
"memory_config_id": memory_config_id,
|
||||
|
||||
@@ -3,17 +3,268 @@ Memory Base Service
|
||||
|
||||
提供记忆服务的基础功能和共享辅助方法。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.db import get_db_context
|
||||
logger = get_logger(__name__)
|
||||
class TranslationResponse(BaseModel):
    """Translation response model."""

    # Translated text.
    data: str
|
||||
|
||||
class MemoryTransService:
    """Memory translation service providing Chinese-to-English translation.

    The LLM client is either supplied by the caller or built lazily from a
    model ID the first time a translation is requested.
    """

    # Parameters passed as ``extra_params`` when the service builds its own client.
    _LLM_EXTRA_PARAMS = {
        "temperature": 0.2,
        "max_tokens": 400,
        "top_p": 0.8,
        "stream": False,
    }

    # ASCII punctuation counted as "English" by ``is_english``.
    _ENGLISH_PUNCTUATION = frozenset('.,!?;:\'"()-')

    # Minimum share of English characters for a text to count as English.
    _ENGLISH_RATIO_THRESHOLD = 0.8

    def __init__(self, llm_client=None, model_id: Optional[str] = None):
        """
        Initialise the translation service.

        Args:
            llm_client: LLM client instance, or a model ID string (optional).
            model_id: Model ID used to build the LLM client lazily (optional).

        Note:
            - A string passed as ``llm_client`` is treated as the model ID;
              in that case the ``model_id`` argument is ignored.
            - Otherwise both values are stored as given, and a non-None
              ``llm_client`` is used in preference to ``model_id``.
        """
        if isinstance(llm_client, str):
            self.model_id = llm_client
            self.llm_client = None
        else:
            self.llm_client = llm_client
            self.model_id = model_id

        self._initialized = False

    def _ensure_llm_client(self):
        """Build the LLM client on first use if the caller did not supply one.

        Raises:
            ValueError: If neither an LLM client nor a model ID is available.
        """
        if self._initialized:
            return

        if self.llm_client is None:
            if not self.model_id:
                raise ValueError("必须提供 llm_client 或 model_id 之一")

            with get_db_context() as db:
                model_config = MemoryConfigService(db).get_model_config(self.model_id)

            self.llm_client = OpenAIClient(
                RedBearModelConfig(
                    model_name=model_config.get("model_name"),
                    provider=model_config.get("provider"),
                    api_key=model_config.get("api_key"),
                    base_url=model_config.get("base_url"),
                    timeout=model_config.get("timeout", 30),
                    max_retries=model_config.get("max_retries", 3),
                    # Copy so a client can never mutate the shared class-level dict.
                    extra_params=dict(self._LLM_EXTRA_PARAMS),
                ),
                type_=model_config.get("type"),
            )

        self._initialized = True

    async def translate_to_english(self, text: str) -> str:
        """
        Translate Chinese text to English.

        Args:
            text: Chinese text to translate.

        Returns:
            The translated text, or ``text`` unchanged when the LLM call fails.

        Raises:
            ValueError: If no LLM client can be built (see ``_ensure_llm_client``).
        """
        self._ensure_llm_client()

        translation_messages = [
            {
                "role": "user",
                "content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
            }
        ]

        try:
            response = await self.llm_client.response_structured(
                messages=translation_messages,
                response_model=TranslationResponse
            )
            return response.data
        except Exception as e:
            logger.error(f"翻译失败: {str(e)}")
            return text  # Fall back to the original text when translation fails.

    async def is_english(self, text: str) -> bool:
        """
        Check whether ``text`` is predominantly English.

        A character counts as English when it is ASCII and is alphanumeric,
        whitespace, or one of ``. , ! ? ; : ' " ( ) -``. The text is English
        when at least 80% of its characters qualify.

        Args:
            text: Text to check (must be a string).

        Returns:
            True if the text is mostly English or blank, False otherwise.

        Raises:
            TypeError: If ``text`` is not a string.
        """
        if not isinstance(text, str):
            raise TypeError(f"is_english 只接受字符串类型,收到: {type(text).__name__}")

        # Blank strings are treated as English. This guard also guarantees
        # ``len(text) > 0`` below, so no separate zero-length check is needed.
        if not text.strip():
            return True

        punctuation = self._ENGLISH_PUNCTUATION
        english_chars = sum(
            1
            for c in text
            if c.isascii() and (c.isalnum() or c.isspace() or c in punctuation)
        )

        return (english_chars / len(text)) >= self._ENGLISH_RATIO_THRESHOLD

    async def Translate(self, text: str, target_language: str = "en") -> str:
        """
        Generic translation entry point (kept for backward compatibility).

        Args:
            text: Text to translate.
            target_language: Target language; only ``"en"`` is supported.

        Returns:
            The translated text, or ``text`` unchanged for any other target language.
        """
        if target_language == "en":
            return await self.translate_to_english(text)

        logger.warning(f"不支持的目标语言: {target_language},返回原文")
        return text
|
||||
|
||||
|
||||
|
||||
# Module-level helper: translate strings, lists and dicts to English
|
||||
# Dict keys translated when the caller does not pass ``fields``.
_DEFAULT_TRANSLATION_FIELDS = (
    'content', 'summary', 'statement', 'description',
    'name', 'aliases', 'caption', 'emotion_keywords',
    'text', 'title', 'label', 'type',
)


async def _translate_text(trans_service, text, failure_label):
    """Translate one string to English; return it unchanged if blank, already English, or on error."""
    if not text.strip():
        return text

    try:
        if await trans_service.is_english(text):
            return text
        return await trans_service.Translate(text)
    except Exception as e:
        logger.warning(f"{failure_label}: {e}")
        return text


async def _translate_value(trans_service, value, fields):
    """Recursively translate ``value`` with one shared service, preserving its structure."""
    if isinstance(value, str):
        return await _translate_text(trans_service, value, "翻译字符串失败")

    if isinstance(value, list):
        translated_items = []
        for item in value:
            if isinstance(item, str):
                # Same handling as a bare string, but with the list-item log label.
                translated_items.append(
                    await _translate_text(trans_service, item, "翻译列表项失败")
                )
            else:
                # Dicts and nested lists recurse; other types come back unchanged.
                translated_items.append(
                    await _translate_value(trans_service, item, fields)
                )
        return translated_items

    if isinstance(value, dict):
        if fields is None:
            fields = list(_DEFAULT_TRANSLATION_FIELDS)

        # Work on a shallow copy so the caller's dict is not modified.
        result = value.copy()
        for field in fields:
            if result.get(field) is None:
                continue
            try:
                result[field] = await _translate_value(trans_service, result[field], fields)
            except Exception as e:
                # Keep the original value when a field cannot be translated.
                logger.warning(f"翻译字段 {field} 失败: {e}")
        return result

    return value


async def Translation_English(modid, text, fields=None):
    """
    Translate data to English, with field-level translation for dicts.

    A single ``MemoryTransService`` is created per top-level call and shared
    by every nested element, so translating a structure does not build a new
    service (and LLM client) for each string.

    Args:
        modid: Model ID used to create the translation service.
        text: Data to translate (string, dict, list, or any other value).
        fields: Dict keys to translate (optional). When None, defaults to
            ['content', 'summary', 'statement', 'description', 'name',
            'aliases', 'caption', 'emotion_keywords', 'text', 'title',
            'label', 'type'].

    Returns:
        The translated data with its original structure.

    Note:
        - Strings: translated unless blank or already mostly English.
        - Lists: each element is processed recursively; length and order are kept.
        - Dicts: only the keys named in ``fields`` are translated, on a shallow copy.
        - Other types: returned unchanged.
    """
    # Values that can never need translation return before any service is built.
    if isinstance(text, str):
        if not text.strip():
            return text
    elif not isinstance(text, (list, dict)):
        return text

    trans_service = MemoryTransService(modid)
    return await _translate_value(trans_service, text, fields)
|
||||
class MemoryBaseService:
|
||||
"""记忆服务基类,提供共享的辅助方法"""
|
||||
|
||||
@@ -294,4 +545,4 @@ class MemoryBaseService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取遗忘记忆数量时出错: {str(e)}", exc_info=True)
|
||||
return 0
|
||||
return 0
|
||||
@@ -125,7 +125,11 @@ class MemoryConfigService:
|
||||
try:
|
||||
validated_config_id = _validate_config_id(config_id)
|
||||
|
||||
# Step 1: Get config and workspace
|
||||
db_query_start = time.time()
|
||||
result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||
db_query_time = time.time() - db_query_start
|
||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||
if not result:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
config_logger.error(
|
||||
@@ -144,16 +148,20 @@ class MemoryConfigService:
|
||||
|
||||
memory_config, workspace = result
|
||||
|
||||
# Validate embedding model
|
||||
embedding_uuid = validate_embedding_model(
|
||||
# Step 2: Validate embedding model (returns both UUID and name)
|
||||
embed_start = time.time()
|
||||
embedding_uuid, embedding_name = validate_embedding_model(
|
||||
validated_config_id,
|
||||
memory_config.embedding_id,
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
workspace.id,
|
||||
)
|
||||
embed_time = time.time() - embed_start
|
||||
logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s")
|
||||
|
||||
# Resolve LLM model
|
||||
# Step 3: Resolve LLM model
|
||||
llm_start = time.time()
|
||||
llm_uuid, llm_name = validate_and_resolve_model_id(
|
||||
memory_config.llm_id,
|
||||
"llm",
|
||||
@@ -163,8 +171,11 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
llm_time = time.time() - llm_start
|
||||
logger.info(f"[PERF] LLM validation: {llm_time:.4f}s")
|
||||
|
||||
# Resolve optional rerank model
|
||||
# Step 4: Resolve optional rerank model
|
||||
rerank_start = time.time()
|
||||
rerank_uuid = None
|
||||
rerank_name = None
|
||||
if memory_config.rerank_id:
|
||||
@@ -177,16 +188,12 @@ class MemoryConfigService:
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
rerank_time = time.time() - rerank_start
|
||||
if memory_config.rerank_id:
|
||||
logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s")
|
||||
|
||||
# Get embedding model name
|
||||
embedding_name, _ = validate_model_exists_and_active(
|
||||
embedding_uuid,
|
||||
"embedding",
|
||||
self.db,
|
||||
workspace.tenant_id,
|
||||
config_id=validated_config_id,
|
||||
workspace_id=workspace.id,
|
||||
)
|
||||
# Note: embedding_name is now returned from validate_embedding_model above
|
||||
# No need for redundant query!
|
||||
|
||||
# Create immutable MemoryConfig object
|
||||
config = MemoryConfig(
|
||||
|
||||
@@ -16,6 +16,7 @@ import json
|
||||
from datetime import datetime
|
||||
|
||||
from app.schemas.memory_episodic_schema import EmotionType
|
||||
from app.services.memory_base_service import Translation_English
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +25,7 @@ class MemoryEntityService:
|
||||
self.id = id
|
||||
self.table = table
|
||||
self.connector = Neo4jConnector()
|
||||
async def get_timeline_memories_server(self):
|
||||
async def get_timeline_memories_server(self,model_id, language_type):
|
||||
"""
|
||||
获取时间线记忆数据
|
||||
|
||||
@@ -48,10 +49,10 @@ class MemoryEntityService:
|
||||
logger.info(f"获取时间线记忆数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
# 根据表类型选择查询
|
||||
if self.table == 'Statement':
|
||||
if self.table == 'Statement':
|
||||
# Statement只需要输入ID,使用简化查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_Statement, id=self.id)
|
||||
elif self.table == 'ExtractedEntity':
|
||||
elif self.table == 'ExtractedEntity':
|
||||
# ExtractedEntity类型查询
|
||||
results = await self.connector.execute_query(Memory_Timeline_ExtractedEntity, id=self.id)
|
||||
else:
|
||||
@@ -62,7 +63,7 @@ class MemoryEntityService:
|
||||
logger.info(f"时间线查询结果类型: {type(results)}, 长度: {len(results) if isinstance(results, list) else 'N/A'}")
|
||||
|
||||
# 处理查询结果
|
||||
timeline_data = self._process_timeline_results(results)
|
||||
timeline_data =await self._process_timeline_results(results, model_id, language_type)
|
||||
|
||||
logger.info(f"成功获取时间线记忆数据: 总计 {len(timeline_data.get('timelines_memory', []))} 条")
|
||||
|
||||
@@ -71,12 +72,14 @@ class MemoryEntityService:
|
||||
except Exception as e:
|
||||
logger.error(f"获取时间线记忆数据失败: {str(e)}", exc_info=True)
|
||||
return str(e)
|
||||
def _process_timeline_results(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
async def _process_timeline_results(self, results: List[Dict[str, Any]], model_id: str, language_type: str) -> Dict[str, Any]:
|
||||
"""
|
||||
处理时间线查询结果
|
||||
|
||||
Args:
|
||||
results: Neo4j查询结果
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
处理后的时间线数据字典
|
||||
@@ -104,19 +107,19 @@ class MemoryEntityService:
|
||||
# 处理MemorySummary
|
||||
summary = data.get('MemorySummary')
|
||||
if summary is not None:
|
||||
processed_summary = self._process_field_value(summary, "MemorySummary")
|
||||
processed_summary = await self._process_field_value(summary, "MemorySummary")
|
||||
memory_summary_list.extend(processed_summary)
|
||||
|
||||
# 处理Statement
|
||||
statement = data.get('statement')
|
||||
if statement is not None:
|
||||
processed_statement = self._process_field_value(statement, "Statement")
|
||||
processed_statement = await self._process_field_value(statement, "Statement")
|
||||
statement_list.extend(processed_statement)
|
||||
|
||||
# 处理ExtractedEntity
|
||||
extracted_entity = data.get('ExtractedEntity')
|
||||
if extracted_entity is not None:
|
||||
processed_entity = self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
processed_entity = await self._process_field_value(extracted_entity, "ExtractedEntity")
|
||||
extracted_entity_list.extend(processed_entity)
|
||||
|
||||
# 去重 - 现在处理的是字典列表,需要更智能的去重
|
||||
@@ -128,6 +131,8 @@ class MemoryEntityService:
|
||||
all_timeline_data = memory_summary_list + statement_list
|
||||
all_timeline_data = self._merge_same_text_items(all_timeline_data)
|
||||
|
||||
# 如果需要翻译(非中文),对整个结果进行翻译
|
||||
|
||||
result = {
|
||||
"MemorySummary": memory_summary_list,
|
||||
"Statement": statement_list,
|
||||
@@ -233,7 +238,7 @@ class MemoryEntityService:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
async def _process_field_value(self, value: Any, field_name: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
处理字段值,支持字符串、列表等类型
|
||||
|
||||
@@ -251,13 +256,13 @@ class MemoryEntityService:
|
||||
# 如果是列表,处理每个元素
|
||||
for item in value:
|
||||
if self._is_valid_item(item):
|
||||
processed_item = self._process_single_item(item)
|
||||
processed_item = await self._process_single_item(item)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, dict):
|
||||
# 如果是字典,直接处理
|
||||
if self._is_valid_item(value):
|
||||
processed_item = self._process_single_item(value)
|
||||
processed_item = await self._process_single_item(value)
|
||||
if processed_item:
|
||||
processed_values.append(processed_item)
|
||||
elif isinstance(value, str):
|
||||
@@ -304,7 +309,7 @@ class MemoryEntityService:
|
||||
return (str(item).strip() != '' and
|
||||
"MemorySummaryChunk" not in str(item))
|
||||
|
||||
def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
async def _process_single_item(self, item: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
处理单个项目
|
||||
|
||||
@@ -369,6 +374,117 @@ class MemoryEntityService:
|
||||
logger.warning(f"转换时间格式失败: {e}, 原始值: {dt}")
|
||||
return str(dt) if dt is not None else None
|
||||
|
||||
async def _translate_list(
|
||||
self,
|
||||
data_list: List[Dict[str, Any]],
|
||||
model_id: str,
|
||||
fields: List[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
翻译列表中每个字典的指定字段(并发有限度以降低整体延迟)
|
||||
|
||||
Args:
|
||||
data_list: 要翻译的字典列表
|
||||
model_id: 模型ID
|
||||
fields: 需要翻译的字段列表
|
||||
|
||||
Returns:
|
||||
翻译后的字典列表
|
||||
"""
|
||||
# 空列表或无字段时直接返回
|
||||
if not data_list or not fields:
|
||||
return data_list
|
||||
|
||||
import asyncio
|
||||
|
||||
# 并发限制,避免一次性发起过多请求
|
||||
# 可根据实际情况调整(建议 5-10)
|
||||
concurrency_limit = 5
|
||||
semaphore = asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
async def translate_single_field(
|
||||
index: int,
|
||||
field: str,
|
||||
value: Any,
|
||||
) -> Optional[tuple]:
|
||||
"""
|
||||
翻译单个字段并返回 (索引, 字段名, 翻译结果)
|
||||
|
||||
Returns:
|
||||
(index, field, translated_value) 或 None(如果跳过)
|
||||
"""
|
||||
# 跳过空值
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
# 统一转成字符串再翻译,防止非字符串类型导致错误
|
||||
text = str(value)
|
||||
|
||||
try:
|
||||
async with semaphore:
|
||||
# 调用 Translation_English 进行翻译
|
||||
# 注意:Translation_English 的参数顺序是 (model_id, text)
|
||||
translated = await Translation_English(model_id, text)
|
||||
|
||||
# 如果翻译结果为空,保留原值
|
||||
if translated is None or translated == "":
|
||||
return None
|
||||
|
||||
return index, field, translated
|
||||
except Exception as e:
|
||||
logger.warning(f"翻译字段 {field} (索引 {index}) 失败: {e}")
|
||||
return None
|
||||
|
||||
# 构造所有需要翻译的任务
|
||||
tasks = []
|
||||
for idx, item in enumerate(data_list):
|
||||
# 防御性检查:确保 item 是字典
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
for field in fields:
|
||||
if field not in item:
|
||||
continue
|
||||
|
||||
value = item.get(field)
|
||||
|
||||
# 对于 None 或空字符串的值,直接跳过,不创建任务
|
||||
if value is None or value == "":
|
||||
continue
|
||||
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
translate_single_field(idx, field, value)
|
||||
)
|
||||
)
|
||||
|
||||
# 如果没有需要翻译的任务,直接返回原列表
|
||||
if not tasks:
|
||||
return data_list
|
||||
|
||||
# 使用 gather 并发执行翻译任务(受 semaphore 限制)
|
||||
# return_exceptions=True 可以防止单个任务失败导致整体失败
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 创建深拷贝以避免修改原始数据
|
||||
translated_list = [item.copy() if isinstance(item, dict) else item for item in data_list]
|
||||
|
||||
# 将翻译结果回填到列表
|
||||
for result in results:
|
||||
# 跳过 None 结果和异常
|
||||
if result is None or isinstance(result, Exception):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"翻译任务异常: {result}")
|
||||
continue
|
||||
|
||||
idx, field, translated = result
|
||||
|
||||
# 防御性检查索引范围
|
||||
if 0 <= idx < len(translated_list) and isinstance(translated_list[idx], dict):
|
||||
translated_list[idx][field] = translated
|
||||
|
||||
return translated_list
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -426,15 +542,19 @@ class MemoryEmotion:
|
||||
# 如果解析失败,返回原始字符串
|
||||
return iso_string
|
||||
|
||||
async def get_emotion(self) -> Dict[str, Any]:
|
||||
async def get_emotion(self, model_id: str = None, language_type: str = 'zh') -> Dict[str, Any]:
|
||||
"""
|
||||
获取情绪随时间变化数据
|
||||
|
||||
Args:
|
||||
model_id: 模型ID用于翻译
|
||||
language_type: 语言类型 ('zh' 或其他)
|
||||
|
||||
Returns:
|
||||
包含情绪数据的字典
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}")
|
||||
logger.info(f"获取情绪数据 - ID: {self.id}, Table: {self.table}, language_type={language_type}")
|
||||
|
||||
if self.table == 'Statement':
|
||||
results = await self.connector.execute_query(Memory_Space_Emotion_Statement, id=self.id)
|
||||
@@ -450,6 +570,10 @@ class MemoryEmotion:
|
||||
# 转换Neo4j类型
|
||||
final_data = self._convert_neo4j_types(emotion_data)
|
||||
|
||||
# 如果需要翻译(非中文)
|
||||
if language_type != 'zh' and model_id and final_data:
|
||||
final_data = await self._translate_emotion_data(final_data, model_id)
|
||||
|
||||
logger.info(f"成功获取 {len(final_data)} 条情绪数据")
|
||||
|
||||
return final_data
|
||||
@@ -590,16 +714,14 @@ class MemoryInteraction:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"获取交互数据 - ID: {self.id}, Table: {self.table}")
|
||||
|
||||
ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id)
|
||||
if ori_data!=[]:
|
||||
# name = ori_data[0]['name']
|
||||
group_id = ori_data[0]['group_id']
|
||||
group_id = [i['group_id'] for i in ori_data][0]
|
||||
Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id)
|
||||
if not Space_User:
|
||||
return []
|
||||
user_id=Space_User[0]['id']
|
||||
|
||||
results = await self.connector.execute_query(Memory_Space_Associative, id=self.id,user_id=user_id)
|
||||
|
||||
|
||||
|
||||
@@ -506,27 +506,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
||||
return result
|
||||
|
||||
|
||||
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Fetch the entity relationship network for one group.

    Runs ``DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH`` with ``end_user_id``
    as the ``group_id`` parameter, then shortens the ``fact_summary`` of the
    source and target node of every returned row.

    Args:
        end_user_id: Value passed to the query as ``group_id``.

    Returns:
        A dict with the search type, the number of rows, and the rows themselves.
    """
    rows = await _neo4j_connector.execute_query(
        DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
        group_id=end_user_id,
    )

    node_keys = ("sourceNode", "targetNode")
    for row in rows:
        # Read both summaries before writing either one back.
        summaries = {key: row[key]["fact_summary"] for key in node_keys}
        for key, summary in summaries.items():
            # Keep only the first four newline-separated segments
            # (presumably a heading plus three "source" entries - TODO confirm).
            row[key]["fact_summary"] = summary.split("\n")[:4] if summary else []

    # Same envelope as the other search helpers: type, count, and details.
    return {
        "search_for": "entity_graph",
        "num": len(rows),
        "detials": rows,
    }
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
|
||||
@@ -18,7 +18,7 @@ from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
from app.services.memory_base_service import MemoryBaseService
|
||||
from app.services.memory_base_service import MemoryBaseService, MemoryTransService, Translation_English
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.services.memory_short_service import ShortService
|
||||
@@ -357,10 +357,107 @@ class UserMemoryService:
|
||||
data[key] = UserMemoryService._datetime_to_timestamp(original_value)
|
||||
return data
|
||||
|
||||
def update_end_user_profile(
    self,
    db: Session,
    end_user_id: str,
    profile_update: Any
) -> Dict[str, Any]:
    """
    Update the basic profile fields of an end user.

    Args:
        db: Database session.
        end_user_id: End user ID (UUID string).
        profile_update: Pydantic model carrying the fields to update; only
            fields that were explicitly set are applied, and ``end_user_id``
            is always excluded.

    Returns:
        {
            "success": bool,
            "data": dict,            # updated profile data, or None on failure
            "error": Optional[str]
        }
    """
    # Parse the ID in its own guard so that only a malformed ID is reported as
    # such. A ValueError raised later (for example a pydantic ValidationError,
    # which subclasses ValueError, or a bad timestamp) must reach the generic
    # handler below and trigger a rollback instead.
    try:
        user_uuid = uuid.UUID(end_user_id)
    except (ValueError, AttributeError, TypeError):
        # AttributeError/TypeError cover non-string IDs such as None or int.
        logger.error(f"无效的 end_user_id 格式: {end_user_id}")
        return {
            "success": False,
            "data": None,
            "error": "无效的用户ID格式"
        }

    try:
        repo = EndUserRepository(db)
        end_user = repo.get_by_id(user_uuid)

        if not end_user:
            logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
            return {
                "success": False,
                "data": None,
                "error": "终端用户不存在"
            }

        # Collect the fields to update (end_user_id itself is never updated).
        update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})

        # hire_date arrives as a timestamp and is stored as a datetime;
        # an explicit None is kept as None so the value can be cleared.
        if 'hire_date' in update_data:
            hire_date_timestamp = update_data['hire_date']
            if hire_date_timestamp is not None:
                from app.core.api_key_utils import timestamp_to_datetime
                update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)

        for field, value in update_data.items():
            setattr(end_user, field, value)

        # Use a single instant so both timestamp columns agree exactly.
        now = datetime.now()
        end_user.updated_at = now
        end_user.updatetime_profile = now

        db.commit()
        db.refresh(end_user)

        # Build the response payload from the refreshed row.
        from app.schemas.end_user_schema import EndUserProfileResponse
        profile_data = EndUserProfileResponse(
            id=end_user.id,
            other_name=end_user.other_name,
            position=end_user.position,
            department=end_user.department,
            contact=end_user.contact,
            phone=end_user.phone,
            hire_date=end_user.hire_date,
            updatetime_profile=end_user.updatetime_profile
        )

        logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")

        return {
            "success": True,
            "data": self.convert_profile_to_dict_with_timestamp(profile_data),
            "error": None
        }

    except Exception as e:
        db.rollback()
        logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
        return {
            "success": False,
            "data": None,
            "error": str(e)
        }
|
||||
|
||||
async def get_cached_memory_insight(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
end_user_id: str,
|
||||
model_id: str,
|
||||
language_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的记忆洞察(四个维度)
|
||||
@@ -419,11 +516,18 @@ class UserMemoryService:
|
||||
key_findings_array = []
|
||||
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察(四维度)")
|
||||
memory_insight=end_user.memory_insight
|
||||
behavior_pattern=end_user.behavior_pattern
|
||||
growth_trajectory=end_user.growth_trajectory
|
||||
if language_type!='zh':
|
||||
memory_insight=await Translation_English(model_id,memory_insight)
|
||||
behavior_pattern=await Translation_English(model_id,behavior_pattern)
|
||||
growth_trajectory=await Translation_English(model_id,growth_trajectory)
|
||||
return {
|
||||
"memory_insight": end_user.memory_insight, # 总体概述存储在 memory_insight
|
||||
"behavior_pattern": end_user.behavior_pattern,
|
||||
"memory_insight":memory_insight, # 总体概述存储在 memory_insight
|
||||
"behavior_pattern":behavior_pattern,
|
||||
"key_findings": key_findings_array, # 返回数组
|
||||
"growth_trajectory": end_user.growth_trajectory,
|
||||
"growth_trajectory": growth_trajectory,
|
||||
"updated_at": self._datetime_to_timestamp(end_user.memory_insight_updated_at),
|
||||
"is_cached": True
|
||||
}
|
||||
@@ -457,7 +561,9 @@ class UserMemoryService:
|
||||
async def get_cached_user_summary(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
end_user_id: str,
|
||||
model_id:str,
|
||||
language_type:str="zh"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的用户摘要(四个部分)
|
||||
@@ -481,7 +587,6 @@ class UserMemoryService:
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
@@ -495,20 +600,29 @@ class UserMemoryService:
|
||||
}
|
||||
|
||||
# 检查是否有缓存数据(至少有一个字段不为空)
|
||||
user_summary=end_user.user_summary
|
||||
personality_traits=end_user.personality_traits
|
||||
core_values=end_user.core_values
|
||||
one_sentence_summary=end_user.one_sentence_summary
|
||||
if language_type!='zh':
|
||||
user_summary=await Translation_English(model_id, user_summary)
|
||||
personality_traits = await Translation_English(model_id, personality_traits)
|
||||
core_values = await Translation_English(model_id, core_values)
|
||||
one_sentence_summary = await Translation_English(model_id, one_sentence_summary)
|
||||
has_cache = any([
|
||||
end_user.user_summary,
|
||||
end_user.personality_traits,
|
||||
end_user.core_values,
|
||||
end_user.one_sentence_summary
|
||||
user_summary,
|
||||
personality_traits,
|
||||
core_values,
|
||||
one_sentence_summary
|
||||
])
|
||||
|
||||
if has_cache:
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
|
||||
return {
|
||||
"user_summary": end_user.user_summary,
|
||||
"personality": end_user.personality_traits,
|
||||
"core_values": end_user.core_values,
|
||||
"one_sentence": end_user.one_sentence_summary,
|
||||
"user_summary": user_summary,
|
||||
"personality": personality_traits,
|
||||
"core_values":core_values,
|
||||
"one_sentence": one_sentence_summary,
|
||||
"updated_at": self._datetime_to_timestamp(end_user.user_summary_updated_at),
|
||||
"is_cached": True
|
||||
}
|
||||
@@ -1367,7 +1481,6 @@ async def analytics_memory_types(
|
||||
|
||||
return memory_types
|
||||
|
||||
|
||||
async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
@@ -1557,7 +1670,7 @@ async def analytics_graph_data(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={len(nodes)}, edges={len(edges)}"
|
||||
)
|
||||
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
@@ -1606,11 +1719,7 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
|
||||
# 获取该节点类型的白名单字段
|
||||
allowed_fields = field_whitelist.get(label, [])
|
||||
|
||||
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
|
||||
# if not allowed_fields:
|
||||
# # 对于未定义的节点类型,只返回基本字段
|
||||
# allowed_fields = ["name", "created_at", "caption"]
|
||||
|
||||
count_neo4j=f"""MATCH (n)-[r]-(m) WHERE elementId(n) ="{node_id}" RETURN count(r) AS rel_count;"""
|
||||
node_results = await (_neo4j_connector.execute_query(count_neo4j))
|
||||
# 提取白名单中的字段
|
||||
@@ -1618,13 +1727,12 @@ async def _extract_node_properties(label: str, properties: Dict[str, Any],node_
|
||||
for field in allowed_fields:
|
||||
if field in properties:
|
||||
value = properties[field]
|
||||
if str(field) == 'entity_type':
|
||||
if str(field) == 'entity_type':
|
||||
value=type_mapping.get(value,'')
|
||||
if str(field)=="emotion_type":
|
||||
value=EmotionType.EMOTION_MAPPING.get(value)
|
||||
if str(field)=="emotion_subject":
|
||||
if str(field)=="emotion_subject":
|
||||
value=EmotionSubject.SUBJECT_MAPPING.get(value)
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
filtered_props['associative_memory']=[i['rel_count'] for i in node_results][0]
|
||||
return filtered_props
|
||||
|
||||
Reference in New Issue
Block a user