config_id字段改成UUID,与develop校对恢复
This commit is contained in:
@@ -223,9 +223,12 @@ async def write_server_async(
|
|||||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
try:
|
try:
|
||||||
|
# 获取标准化的消息列表
|
||||||
|
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||||
|
|
||||||
task = celery_app.send_task(
|
task = celery_app.send_task(
|
||||||
"app.core.memory.agent.write_message",
|
"app.core.memory.agent.write_message",
|
||||||
args=[user_input.end_user_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||||
)
|
)
|
||||||
api_logger.info(f"Write task queued: {task.id}")
|
api_logger.info(f"Write task queued: {task.id}")
|
||||||
|
|
||||||
@@ -598,7 +601,7 @@ async def status_type(
|
|||||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||||
|
|
||||||
result = await memory_agent_service.classify_message_type(
|
result = await memory_agent_service.classify_message_type(
|
||||||
user_input.messages,
|
last_user_message,
|
||||||
user_input.config_id,
|
user_input.config_id,
|
||||||
db
|
db
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -237,7 +237,7 @@ async def update_forgetting_config(
|
|||||||
|
|
||||||
@router.get("/stats", response_model=ApiResponse)
|
@router.get("/stats", response_model=ApiResponse)
|
||||||
async def get_forgetting_stats(
|
async def get_forgetting_stats(
|
||||||
group_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
@@ -263,18 +263,18 @@ async def get_forgetting_stats(
|
|||||||
|
|
||||||
# 如果提供了 group_id,通过它获取 config_id
|
# 如果提供了 group_id,通过它获取 config_id
|
||||||
config_id = None
|
config_id = None
|
||||||
if group_id:
|
if end_user_id:
|
||||||
try:
|
try:
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
connected_config = get_end_user_connected_config(group_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
config_id = connected_config.get("memory_config_id")
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
|
api_logger.debug(f"通过 group_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
@@ -284,14 +284,14 @@ async def get_forgetting_stats(
|
|||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||||
f"group_id={group_id}, config_id={config_id}"
|
f"group_id={end_user_id}, config_id={config_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用服务层获取统计信息
|
# 调用服务层获取统计信息
|
||||||
stats = await forget_service.get_forgetting_stats(
|
stats = await forget_service.get_forgetting_stats(
|
||||||
db=db,
|
db=db,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
config_id=config_id
|
config_id=config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ async def write_memory_api_service(
|
|||||||
config_id=payload.config_id,
|
config_id=payload.config_id,
|
||||||
storage_type=payload.storage_type,
|
storage_type=payload.storage_type,
|
||||||
user_rag_memory_id=payload.user_rag_memory_id,
|
user_rag_memory_id=payload.user_rag_memory_id,
|
||||||
tenant_id=api_key_auth.tenant_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory write successful for end_user: {payload.end_user_id}")
|
||||||
|
|||||||
@@ -145,33 +145,36 @@ class LangChainAgent:
|
|||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||||
end_user_end=f"Term_{end_user_end}"
|
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||||
print(messages)
|
# end_user_end=f"Term_{end_user_end}"
|
||||||
print(aimessages)
|
# print(messages)
|
||||||
session_id = store.save_session(
|
# print(aimessages)
|
||||||
userid=end_user_end,
|
# session_id = store.save_session(
|
||||||
messages=messages,
|
# userid=end_user_end,
|
||||||
apply_id=end_user_end,
|
# messages=messages,
|
||||||
end_user_id=end_user_end,
|
# apply_id=end_user_end,
|
||||||
aimessages=aimessages
|
# group_id=end_user_end,
|
||||||
)
|
# aimessages=aimessages
|
||||||
store.delete_duplicate_sessions()
|
# )
|
||||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
# store.delete_duplicate_sessions()
|
||||||
return session_id
|
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||||
async def term_memory_redis_read(self,end_user_end):
|
# return session_id
|
||||||
end_user_end = f"Term_{end_user_end}"
|
|
||||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
# async def term_memory_redis_read(self,end_user_end):
|
||||||
messagss_list=[]
|
# end_user_end = f"Term_{end_user_end}"
|
||||||
retrieved_content=[]
|
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||||
for messages in history:
|
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||||
query = messages.get("Query")
|
# messagss_list=[]
|
||||||
aimessages = messages.get("Answer")
|
# retrieved_content=[]
|
||||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
# for messages in history:
|
||||||
retrieved_content.append({query: aimessages})
|
# query = messages.get("Query")
|
||||||
return messagss_list,retrieved_content
|
# aimessages = messages.get("Answer")
|
||||||
|
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||||
|
# retrieved_content.append({query: aimessages})
|
||||||
|
# return messagss_list,retrieved_content
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -34,11 +34,17 @@ async def make_write_graph():
|
|||||||
end_user_id: Group identifier
|
end_user_id: Group identifier
|
||||||
memory_config: MemoryConfig object containing all configuration
|
memory_config: MemoryConfig object containing all configuration
|
||||||
"""
|
"""
|
||||||
|
# workflow = StateGraph(WriteState)
|
||||||
|
# workflow.add_node("content_input", content_input_write)
|
||||||
|
# workflow.add_node("save_neo4j", write_node)
|
||||||
|
# workflow.add_edge(START, "content_input")
|
||||||
|
# workflow.add_edge("content_input", "save_neo4j")
|
||||||
|
# workflow.add_edge("save_neo4j", END)
|
||||||
|
#
|
||||||
|
# graph = workflow.compile()
|
||||||
workflow = StateGraph(WriteState)
|
workflow = StateGraph(WriteState)
|
||||||
workflow.add_node("content_input", content_input_write)
|
|
||||||
workflow.add_node("save_neo4j", write_node)
|
workflow.add_node("save_neo4j", write_node)
|
||||||
workflow.add_edge(START, "content_input")
|
workflow.add_edge(START, "save_neo4j")
|
||||||
workflow.add_edge("content_input", "save_neo4j")
|
|
||||||
workflow.add_edge("save_neo4j", END)
|
workflow.add_edge("save_neo4j", END)
|
||||||
|
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_end_user_id,
|
SELECTED_GROUP_ID,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from app.core.memory.storage_services.search import run_hybrid_search
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_end_user_id,
|
SELECTED_GROUP_ID,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -136,7 +136,7 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
|||||||
|
|
||||||
|
|
||||||
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||||
end_user_id = end_user_id or SELECTED_end_user_id
|
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||||
# Load data
|
# Load data
|
||||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||||
if not os.path.exists(data_path):
|
if not os.path.exists(data_path):
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
|||||||
from app.core.memory.utils.config.definitions import (
|
from app.core.memory.utils.config.definitions import (
|
||||||
PROJECT_ROOT,
|
PROJECT_ROOT,
|
||||||
SELECTED_EMBEDDING_ID,
|
SELECTED_EMBEDDING_ID,
|
||||||
SELECTED_end_user_id,
|
SELECTED_GROUP_ID,
|
||||||
SELECTED_LLM_ID,
|
SELECTED_LLM_ID,
|
||||||
)
|
)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ except Exception:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.core.memory.utils.config.definitions import SELECTED_end_user_id, PROJECT_ROOT
|
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||||
|
|
||||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||||
@@ -37,7 +37,7 @@ async def run(
|
|||||||
max_contexts_per_item: int | None = None,
|
max_contexts_per_item: int | None = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||||
end_user_id = end_user_id or SELECTED_end_user_id
|
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||||
|
|
||||||
if reset_group:
|
if reset_group:
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
|
|||||||
@@ -693,9 +693,6 @@ async def run_hybrid_search(
|
|||||||
# Start overall timing
|
# Start overall timing
|
||||||
search_start_time = time.time()
|
search_start_time = time.time()
|
||||||
latency_metrics = {}
|
latency_metrics = {}
|
||||||
print(100*'-')
|
|
||||||
print(memory_config)
|
|
||||||
print(100 * '-')
|
|
||||||
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
logger.info(f"using embedding_id:{memory_config.embedding_model_id}...")
|
||||||
|
|
||||||
# Clean and normalize the incoming query before use/logging
|
# Clean and normalize the incoming query before use/logging
|
||||||
|
|||||||
@@ -209,7 +209,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
|
|
||||||
results = await self.connector.execute_query(
|
results = await self.connector.execute_query(
|
||||||
query,
|
query,
|
||||||
group_id=group_id,
|
end_user_id=end_user_id,
|
||||||
days=days,
|
days=days,
|
||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
@@ -253,7 +253,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
results = await self.connector.execute_query(query, **params)
|
results = await self.connector.execute_query(query, **params)
|
||||||
return [self._map_to_dict(r) for r in results]
|
return [self._map_to_dict(r) for r in results]
|
||||||
|
|
||||||
async def get_summary_count_by_group(self, group_id: str) -> int:
|
async def get_summary_count_by_group(self, end_user_id: str) -> int:
|
||||||
"""Get count of memory summaries for a group
|
"""Get count of memory summaries for a group
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -268,6 +268,6 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
|||||||
RETURN count(n) as count
|
RETURN count(n) as count
|
||||||
"""
|
"""
|
||||||
|
|
||||||
results = await self.connector.execute_query(query, end_user_id=group_id)
|
results = await self.connector.execute_query(query, end_user_id=end_user_id)
|
||||||
return results[0]['count'] if results else 0
|
return results[0]['count'] if results else 0
|
||||||
|
|
||||||
@@ -70,11 +70,7 @@ class Neo4jConnector:
|
|||||||
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
|
List[Dict[str, Any]]: 查询结果列表,每个元素是一个字典
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> connector = Neo4jConnector()
|
|
||||||
>>> results = await connector.execute_query(
|
|
||||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
|
||||||
... name="Alice"
|
|
||||||
... )
|
|
||||||
"""
|
"""
|
||||||
result = await self.driver.execute_query(
|
result = await self.driver.execute_query(
|
||||||
query,
|
query,
|
||||||
@@ -98,17 +94,7 @@ class Neo4jConnector:
|
|||||||
Any: 事务函数的返回值
|
Any: 事务函数的返回值
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> async def create_node(tx, name):
|
|
||||||
... result = await tx.run(
|
|
||||||
... "CREATE (n:Person {name: $name}) RETURN n",
|
|
||||||
... name=name
|
|
||||||
... )
|
|
||||||
... return await result.single()
|
|
||||||
>>>
|
|
||||||
>>> connector = Neo4jConnector()
|
|
||||||
>>> result = await connector.execute_write_transaction(
|
|
||||||
... create_node, name="Alice"
|
|
||||||
... )
|
|
||||||
"""
|
"""
|
||||||
async with self.driver.session(database="neo4j") as session:
|
async with self.driver.session(database="neo4j") as session:
|
||||||
return await session.execute_write(transaction_func, **kwargs)
|
return await session.execute_write(transaction_func, **kwargs)
|
||||||
@@ -126,17 +112,7 @@ class Neo4jConnector:
|
|||||||
Any: 事务函数的返回值
|
Any: 事务函数的返回值
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> async def get_node(tx, name):
|
|
||||||
... result = await tx.run(
|
|
||||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
|
||||||
... name=name
|
|
||||||
... )
|
|
||||||
... return await result.single()
|
|
||||||
>>>
|
|
||||||
>>> connector = Neo4jConnector()
|
|
||||||
>>> result = await connector.execute_read_transaction(
|
|
||||||
... get_node, name="Alice"
|
|
||||||
... )
|
|
||||||
"""
|
"""
|
||||||
async with self.driver.session(database="neo4j") as session:
|
async with self.driver.session(database="neo4j") as session:
|
||||||
return await session.execute_read(transaction_func, **kwargs)
|
return await session.execute_read(transaction_func, **kwargs)
|
||||||
@@ -151,8 +127,6 @@ class Neo4jConnector:
|
|||||||
end_user_id: 要删除的组ID
|
end_user_id: 要删除的组ID
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> connector = Neo4jConnector()
|
|
||||||
>>> await connector.delete_group("group_123")
|
|
||||||
Group group_123 deleted.
|
Group group_123 deleted.
|
||||||
"""
|
"""
|
||||||
# 删除节点(DETACH DELETE会同时删除相关的边)
|
# 删除节点(DETACH DELETE会同时删除相关的边)
|
||||||
|
|||||||
@@ -564,7 +564,7 @@ class MemoryAgentService:
|
|||||||
# 使用 upsert 方法
|
# 使用 upsert 方法
|
||||||
repo.upsert(
|
repo.upsert(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages=message,
|
messages=ori_message,
|
||||||
aimessages=summary,
|
aimessages=summary,
|
||||||
retrieved_content=retrieved_content,
|
retrieved_content=retrieved_content,
|
||||||
search_switch=str(search_switch)
|
search_switch=str(search_switch)
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ class MemoryAPIService:
|
|||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
result = await MemoryAgentService().write_memory(
|
result = await MemoryAgentService().write_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
message=message,
|
messages=message,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
|
|||||||
@@ -30,9 +30,10 @@ config_logger = get_config_logger()
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
def _validate_config_id(config_id):
|
def _validate_config_id(config_id):
|
||||||
"""Validate configuration ID format."""
|
"""Validate configuration ID format (supports both UUID and integer)."""
|
||||||
if isinstance(config_id, uuid.UUID):
|
if isinstance(config_id, uuid.UUID):
|
||||||
return config_id
|
return config_id
|
||||||
|
|
||||||
if config_id is None:
|
if config_id is None:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
"Configuration ID cannot be None",
|
"Configuration ID cannot be None",
|
||||||
@@ -50,8 +51,17 @@ def _validate_config_id(config_id):
|
|||||||
return config_id
|
return config_id
|
||||||
|
|
||||||
if isinstance(config_id, str):
|
if isinstance(config_id, str):
|
||||||
|
config_id_stripped = config_id.strip()
|
||||||
|
|
||||||
|
# Try parsing as UUID first
|
||||||
try:
|
try:
|
||||||
parsed_id = int(config_id.strip())
|
return uuid.UUID(config_id_stripped)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Fall back to integer parsing
|
||||||
|
try:
|
||||||
|
parsed_id = int(config_id_stripped)
|
||||||
if parsed_id <= 0:
|
if parsed_id <= 0:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
f"Configuration ID must be positive: {parsed_id}",
|
f"Configuration ID must be positive: {parsed_id}",
|
||||||
@@ -61,13 +71,13 @@ def _validate_config_id(config_id):
|
|||||||
return parsed_id
|
return parsed_id
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
f"Invalid configuration ID format: '{config_id}'",
|
f"Invalid configuration ID format: '{config_id}' (must be UUID or positive integer)",
|
||||||
field_name="config_id",
|
field_name="config_id",
|
||||||
invalid_value=config_id,
|
invalid_value=config_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
raise InvalidConfigError(
|
raise InvalidConfigError(
|
||||||
f"Invalid type for configuration ID: expected int or str, got {type(config_id).__name__}",
|
f"Invalid type for configuration ID: expected UUID, int or str, got {type(config_id).__name__}",
|
||||||
field_name="config_id",
|
field_name="config_id",
|
||||||
invalid_value=config_id,
|
invalid_value=config_id,
|
||||||
)
|
)
|
||||||
@@ -113,7 +123,7 @@ class MemoryConfigService:
|
|||||||
ConfigurationError: If validation fails
|
ConfigurationError: If validation fails
|
||||||
"""
|
"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
validated_config_id = _validate_config_id(config_id)
|
|
||||||
config_logger.info(
|
config_logger.info(
|
||||||
"Starting memory configuration loading",
|
"Starting memory configuration loading",
|
||||||
extra={
|
extra={
|
||||||
@@ -126,27 +136,11 @@ class MemoryConfigService:
|
|||||||
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
logger.info(f"Loading memory configuration from database: config_id={config_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Validate config_id is UUID
|
validated_config_id = _validate_config_id(config_id)
|
||||||
if not isinstance(config_id, UUID):
|
|
||||||
if isinstance(config_id, str):
|
|
||||||
try:
|
|
||||||
config_id = UUID(config_id)
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidConfigError(
|
|
||||||
f"Invalid UUID format for config_id: {config_id}",
|
|
||||||
field_name="config_id",
|
|
||||||
invalid_value=config_id,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise InvalidConfigError(
|
|
||||||
f"config_id must be UUID or valid UUID string, got {type(config_id).__name__}",
|
|
||||||
field_name="config_id",
|
|
||||||
invalid_value=config_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 1: Get config and workspace
|
# Step 1: Get config and workspace
|
||||||
db_query_start = time.time()
|
db_query_start = time.time()
|
||||||
result = MemoryConfigRepository.get_config_with_workspace(self.db, config_id)
|
result = MemoryConfigRepository.get_config_with_workspace(self.db, validated_config_id)
|
||||||
db_query_time = time.time() - db_query_start
|
db_query_time = time.time() - db_query_start
|
||||||
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s")
|
||||||
if not result:
|
if not result:
|
||||||
@@ -170,7 +164,7 @@ class MemoryConfigService:
|
|||||||
# Step 2: Validate embedding model (returns both UUID and name)
|
# Step 2: Validate embedding model (returns both UUID and name)
|
||||||
embed_start = time.time()
|
embed_start = time.time()
|
||||||
embedding_uuid, embedding_name = validate_embedding_model(
|
embedding_uuid, embedding_name = validate_embedding_model(
|
||||||
config_id,
|
validated_config_id,
|
||||||
memory_config.embedding_id,
|
memory_config.embedding_id,
|
||||||
self.db,
|
self.db,
|
||||||
workspace.tenant_id,
|
workspace.tenant_id,
|
||||||
@@ -187,7 +181,7 @@ class MemoryConfigService:
|
|||||||
self.db,
|
self.db,
|
||||||
workspace.tenant_id,
|
workspace.tenant_id,
|
||||||
required=True,
|
required=True,
|
||||||
config_id=config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
llm_time = time.time() - llm_start
|
llm_time = time.time() - llm_start
|
||||||
@@ -204,7 +198,7 @@ class MemoryConfigService:
|
|||||||
self.db,
|
self.db,
|
||||||
workspace.tenant_id,
|
workspace.tenant_id,
|
||||||
required=False,
|
required=False,
|
||||||
config_id=config_id,
|
config_id=validated_config_id,
|
||||||
workspace_id=workspace.id,
|
workspace_id=workspace.id,
|
||||||
)
|
)
|
||||||
rerank_time = time.time() - rerank_start
|
rerank_time = time.time() - rerank_start
|
||||||
@@ -262,7 +256,7 @@ class MemoryConfigService:
|
|||||||
extra={
|
extra={
|
||||||
"operation": "load_memory_config",
|
"operation": "load_memory_config",
|
||||||
"service": service_name,
|
"service": service_name,
|
||||||
"config_id": str(config_id),
|
"config_id": validated_config_id,
|
||||||
"config_name": config.config_name,
|
"config_name": config.config_name,
|
||||||
"workspace_id": str(config.workspace_id),
|
"workspace_id": str(config.workspace_id),
|
||||||
"load_result": "success",
|
"load_result": "success",
|
||||||
|
|||||||
@@ -505,29 +505,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
|||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
|
||||||
"""搜索所有实体之间的关系网络(group 维度)。"""
|
|
||||||
result = await _neo4j_connector.execute_query(
|
|
||||||
DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
|
||||||
# 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”)
|
|
||||||
for item in result:
|
|
||||||
source_fact = item["sourceNode"]["fact_summary"]
|
|
||||||
target_fact = item["targetNode"]["fact_summary"]
|
|
||||||
# 截取前三条“来源”
|
|
||||||
item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else []
|
|
||||||
item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else []
|
|
||||||
# 与现有返回风格保持一致,携带搜索类型、数量与详情
|
|
||||||
data = {
|
|
||||||
"search_for": "entity_graph",
|
|
||||||
"num": len(result),
|
|
||||||
"detials": result,
|
|
||||||
}
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
async def analytics_hot_memory_tags(
|
async def analytics_hot_memory_tags(
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User,
|
current_user: User,
|
||||||
|
|||||||
@@ -531,7 +531,7 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto
|
|||||||
except Exception:
|
except Exception:
|
||||||
# Log but continue - will fail later with proper error
|
# Log but continue - will fail later with proper error
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _run() -> str:
|
async def _run() -> str:
|
||||||
db = next(get_db())
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
@@ -619,53 +619,53 @@ def reflection_timer_task() -> None:
|
|||||||
"""
|
"""
|
||||||
reflection_engine()
|
reflection_engine()
|
||||||
|
|
||||||
|
# unused task
|
||||||
@celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
# @celery_app.task(name="app.core.memory.agent.health.check_read_service")
|
||||||
def check_read_service_task() -> Dict[str, str]:
|
# def check_read_service_task() -> Dict[str, str]:
|
||||||
"""Call read_service and write latest status to Redis.
|
# """Call read_service and write latest status to Redis.
|
||||||
|
|
||||||
Returns status data dict that gets written to Redis.
|
# Returns status data dict that gets written to Redis.
|
||||||
"""
|
# """
|
||||||
client = redis.Redis(
|
# client = redis.Redis(
|
||||||
host=settings.REDIS_HOST,
|
# host=settings.REDIS_HOST,
|
||||||
port=settings.REDIS_PORT,
|
# port=settings.REDIS_PORT,
|
||||||
db=settings.REDIS_DB,
|
# db=settings.REDIS_DB,
|
||||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None
|
||||||
)
|
# )
|
||||||
try:
|
# try:
|
||||||
api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service"
|
||||||
payload = {
|
# payload = {
|
||||||
"user_id": "健康检查",
|
# "user_id": "健康检查",
|
||||||
"apply_id": "健康检查",
|
# "apply_id": "健康检查",
|
||||||
"end_user_id": "健康检查",
|
# "group_id": "健康检查",
|
||||||
"message": "你好",
|
# "message": "你好",
|
||||||
"history": [],
|
# "history": [],
|
||||||
"search_switch": "2",
|
# "search_switch": "2",
|
||||||
}
|
# }
|
||||||
resp = requests.post(api_url, json=payload, timeout=15)
|
# resp = requests.post(api_url, json=payload, timeout=15)
|
||||||
ok = resp.status_code == 200
|
# ok = resp.status_code == 200
|
||||||
status = "Success" if ok else "Fail"
|
# status = "Success" if ok else "Fail"
|
||||||
msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}"
|
||||||
error = "" if ok else resp.text
|
# error = "" if ok else resp.text
|
||||||
code = 0 if ok else 500
|
# code = 0 if ok else 500
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
status = "Fail"
|
# status = "Fail"
|
||||||
msg = "接口请求失败"
|
# msg = "接口请求失败"
|
||||||
error = str(e)
|
# error = str(e)
|
||||||
code = 500
|
# code = 500
|
||||||
|
|
||||||
data = {
|
# data = {
|
||||||
"status": status,
|
# "status": status,
|
||||||
"msg": msg,
|
# "msg": msg,
|
||||||
"error": error,
|
# "error": error,
|
||||||
"code": str(code),
|
# "code": str(code),
|
||||||
"time": str(int(time.time())),
|
# "time": str(int(time.time())),
|
||||||
}
|
# }
|
||||||
|
|
||||||
client.hset("memsci:health:read_service", mapping=data)
|
# client.hset("memsci:health:read_service", mapping=data)
|
||||||
client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS))
|
||||||
|
|
||||||
return data
|
# return data
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
@celery_app.task(name="app.controllers.memory_storage_controller.search_all")
|
||||||
|
|||||||
Reference in New Issue
Block a user